9 #include <condition_variable>
15 #include <unordered_map>
18 #include <rapidsmpf/communicator/communicator.hpp>
19 #include <rapidsmpf/communicator/metadata_payload_exchange/tag.hpp>
20 #include <rapidsmpf/error.hpp>
21 #include <rapidsmpf/memory/buffer_resource.hpp>
22 #include <rapidsmpf/memory/packed_data.hpp>
23 #include <rapidsmpf/nvtx.hpp>
24 #include <rapidsmpf/progress_thread.hpp>
25 #include <rapidsmpf/shuffler/chunk.hpp>
26 #include <rapidsmpf/shuffler/finish_counter.hpp>
27 #include <rapidsmpf/shuffler/postbox.hpp>
28 #include <rapidsmpf/statistics.hpp>
29 #include <rapidsmpf/utils/misc.hpp>
64 std::shared_ptr<Communicator>
const&
comm,
68 return safe_cast<Rank>(pid % safe_cast<PartID>(
comm->nranks()));
85 return safe_cast<Rank>(
99 std::shared_ptr<Communicator>
const&
comm,
133 std::shared_ptr<Communicator>
comm,
139 std::unique_ptr<communicator::MetadataPayloadExchange> mpe =
nullptr
159 std::shared_ptr<Communicator>
comm,
164 std::unique_ptr<communicator::MetadataPayloadExchange> mpe =
nullptr
183 [[nodiscard]] std::shared_ptr<Communicator>
const&
comm() const noexcept {
205 void insert(std::unordered_map<PartID, PackedData>&& chunks);
247 void wait(std::optional<std::chrono::milliseconds> timeout = {});
262 std::size_t
spill(std::optional<std::size_t> amount = std::nullopt);
268 [[nodiscard]] std::string
str()
const;
325 std::atomic<bool> active_{
true};
327 std::atomic<bool> locally_finished_{
false};
330 bool can_extract_{
false};
331 detail::ChunksToSend to_send_;
332 detail::ReceivedChunks received_;
335 std::shared_ptr<Communicator> comm_;
336 std::unique_ptr<communicator::MetadataPayloadExchange> mpe_;
337 ProgressThread::FunctionID progress_thread_function_id_;
341 std::vector<PartID>
const local_partitions_;
343 detail::FinishCounter finish_counter_;
344 std::vector<detail::ChunkID> outbound_chunk_counter_;
345 std::atomic<detail::ChunkID> chunk_id_counter_{0};
347 std::shared_ptr<Statistics> statistics_;
350 mutable std::mutex mutex_;
351 std::condition_variable cv_;
Class managing buffer resources.
std::size_t SpillFunctionID
Represents a unique identifier for a registered spill function.
Shuffle service for all-to-all style communication of partitioned data.
Shuffler(std::shared_ptr< Communicator > comm, OpID op_id, PartID total_num_partitions, BufferResource *br, FinishedCallback &&finished_callback, PartitionOwner partition_owner=round_robin, std::unique_ptr< communicator::MetadataPayloadExchange > mpe=nullptr)
Construct a new shuffler for a single shuffle.
void shutdown()
Shutdown the shuffle, blocking until all inflight communication is done.
void insert(std::unordered_map< PartID, PackedData > &&chunks)
Insert a batch of packed (serialized) chunks into the shuffle.
std::size_t spill(std::optional< std::size_t > amount=std::nullopt)
Spills data to device if necessary.
static constexpr int chunk_id_counter_bits
The number of bits used to store the counter in a chunk ID.
static std::vector< PartID > local_partitions(std::shared_ptr< Communicator > const &comm, PartID total_num_partitions, PartitionOwner partition_owner)
Returns the local partition IDs owned by the current node.
static Rank contiguous(std::shared_ptr< Communicator > const &comm, PartID pid, PartID total_num_partitions)
A PartitionOwner that assigns contiguous partition ID ranges to ranks.
std::function< Rank(std::shared_ptr< Communicator > const &, PartID, PartID)> PartitionOwner
Function that given a Communicator, PartID, and total partition count, returns the rapidsmpf::Rank of...
void insert_finished()
Signal that no more data will be inserted into the shuffle.
PartID const total_num_partitions
Total number of partitions in the shuffle.
static Rank round_robin(std::shared_ptr< Communicator > const &comm, PartID pid, [[maybe_unused]] PartID total_num_partitions)
A PartitionOwner that distributes partitions using round robin assignment.
std::string str() const
Returns a description of this instance.
static constexpr Rank extract_rank(detail::ChunkID cid)
Extract the rank from a chunk ID.
bool finished() const
Check if all partitions are finished.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this Shuffler.
Shuffler(std::shared_ptr< Communicator > comm, OpID op_id, PartID total_num_partitions, BufferResource *br, PartitionOwner partition_owner=round_robin, std::unique_ptr< communicator::MetadataPayloadExchange > mpe=nullptr)
Construct a new shuffler for a single shuffle.
std::vector< PackedData > extract(PartID pid)
Extract all chunks belonging to the specified partition.
std::span< PartID const > local_partitions() const
Returns the local partition IDs owned by the shuffler.
void wait(std::optional< std::chrono::milliseconds > timeout={})
Wait for all partitions to finish (blocking).
PartitionOwner const partition_owner
Function to determine partition ownership.
std::function< void()> FinishedCallback
Callback function type called when all partitions are finished and data can be extracted.
A partition chunk representing either a data message or a control message.
std::uint64_t ChunkID
The globally unique ID of a chunk.
std::uint32_t PartID
Partition ID, which goes from 0 to the total number of partitions.
std::ostream & operator<<(std::ostream &os, detail::FinishCounter const &obj)
Overloads the stream insertion operator for the FinishCounter class.
std::int32_t Rank
The rank of a node (e.g. the rank of an MPI process), or world size (total number of ranks).
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations,...
Bag of bytes with metadata suitable for sending over the wire.