14 #include <unordered_map>
17 #include <rapidsmpf/buffer/packed_data.hpp>
18 #include <rapidsmpf/buffer/resource.hpp>
19 #include <rapidsmpf/communicator/communicator.hpp>
20 #include <rapidsmpf/error.hpp>
21 #include <rapidsmpf/nvtx.hpp>
22 #include <rapidsmpf/progress_thread.hpp>
23 #include <rapidsmpf/shuffler/chunk.hpp>
24 #include <rapidsmpf/shuffler/finish_counter.hpp>
25 #include <rapidsmpf/shuffler/postbox.hpp>
26 #include <rapidsmpf/statistics.hpp>
27 #include <rapidsmpf/utils.hpp>
30 class ShuffleInsertGroupedTest;
48 friend class ::ShuffleInsertGroupedTest;
65 return static_cast<Rank
>(pid %
static_cast<PartID>(comm->nranks()));
77 std::shared_ptr<Communicator>
const& comm,
103 std::shared_ptr<Communicator> comm,
104 std::shared_ptr<ProgressThread> progress_thread,
130 std::shared_ptr<Communicator> comm,
131 std::shared_ptr<ProgressThread> progress_thread,
174 void insert(std::unordered_map<PartID, PackedData>&& chunks);
240 void wait_on(
PartID pid, std::optional<std::chrono::milliseconds> timeout = {});
258 std::size_t
spill(std::optional<std::size_t> amount = std::nullopt);
264 [[nodiscard]] std::string
str()
const;
344 std::atomic<bool> active_{
true};
350 std::shared_ptr<Communicator> comm_;
351 std::shared_ptr<ProgressThread> progress_thread_;
357 std::vector<PartID>
const local_partitions_;
360 std::unordered_map<PartID, detail::ChunkID> outbound_chunk_counter_;
361 mutable std::mutex outbound_chunk_counter_mutex_;
365 mutable std::mutex ready_postbox_spilling_mutex_;
367 std::atomic<detail::ChunkID> chunk_id_counter_{0};
369 std::shared_ptr<Statistics> statistics_;
Class managing buffer resources.
std::size_t SpillFunctionID
Represents a unique identifier for a registered spill function.
static std::shared_ptr< Statistics > disabled()
Returns a shared pointer to a disabled (no-op) Statistics instance.
Shuffle service for cuDF tables.
static constexpr uint64_t counter_mask
The mask for the counter in a chunk ID.
void insert_finished(PartID pid)
Insert a finish mark for a partition.
Shuffler(std::shared_ptr< Communicator > comm, std::shared_ptr< ProgressThread > progress_thread, OpID op_id, PartID total_num_partitions, BufferResource *br, FinishedCallback &&finished_callback, std::shared_ptr< Statistics > statistics=Statistics::disabled(), PartitionOwner partition_owner=round_robin)
Construct a new shuffler for a single shuffle.
PartID wait_any(std::optional< std::chrono::milliseconds > timeout={})
Wait for any partition to finish.
void shutdown()
Shutdown the shuffle, blocking until all inflight communication is done.
void insert(std::unordered_map< PartID, PackedData > &&chunks)
Insert a bunch of packed (serialized) chunks into the shuffle.
std::size_t spill(std::optional< std::size_t > amount=std::nullopt)
Spills data to device if necessary.
static constexpr int chunk_id_counter_bits
The number of bits used to store the counter in a chunk ID.
Shuffler(std::shared_ptr< Communicator > comm, std::shared_ptr< ProgressThread > progress_thread, OpID op_id, PartID total_num_partitions, BufferResource *br, std::shared_ptr< Statistics > statistics=Statistics::disabled(), PartitionOwner partition_owner=round_robin)
Construct a new shuffler for a single shuffle.
static std::vector< PartID > local_partitions(std::shared_ptr< Communicator > const &comm, PartID total_num_partitions, PartitionOwner partition_owner)
Returns the local partition IDs owned by the current node.
void concat_insert(std::unordered_map< PartID, PackedData > &&chunks)
Insert a map of packed data, grouping them by destination rank, and concatenating into a single chunk...
void wait_on(PartID pid, std::optional< std::chrono::milliseconds > timeout={})
Wait for a specific partition to finish (blocking).
static Rank round_robin(std::shared_ptr< Communicator > const &comm, PartID pid)
A PartitionOwner that distributes the partitions using round robin.
void insert_finished(std::vector< PartID > &&pids)
Insert a finish mark for a list of partitions.
PartID const total_num_partitions
Total number of partitions in the shuffle.
std::string str() const
Returns a description of this instance.
static constexpr std::pair< Rank, uint64_t > extract_info(detail::ChunkID cid)
Extract the rank and counter from a chunk ID.
std::function< Rank(std::shared_ptr< Communicator >, PartID)> PartitionOwner
Function that, given a Communicator and a PartID, returns the rapidsmpf::Rank of the owning node.
static constexpr Rank extract_rank(detail::ChunkID cid)
Extract the rank from a chunk ID.
bool finished() const
Check if all partitions are finished.
detail::FinishCounter::FinishedCallback FinishedCallback
Callback function type called when a partition is finished.
static constexpr uint64_t extract_counter(detail::ChunkID cid)
Extract the counter from a chunk ID.
std::vector< PackedData > extract(PartID pid)
Extract all chunks belonging to the specified partition.
std::span< PartID const > local_partitions() const
Returns the local partition IDs owned by the shuffler.
bool is_finished(PartID pid) const
Check if a partition is finished.
PartitionOwner const partition_owner
Function to determine partition ownership.
Chunk with multiple messages. This class contains two buffers for concatenated metadata and data.
Helper to tally the finish status of a shuffle.
std::function< void(PartID)> FinishedCallback
Callback function type called when a partition is finished.
std::uint64_t ChunkID
The globally unique ID of a chunk.
std::uint32_t PartID
Partition ID, which goes from 0 to the total number of partitions.
std::ostream & operator<<(std::ostream &os, detail::FinishCounter const &obj)
Overloads the stream insertion operator for the FinishCounter class.
Bag of bytes with metadata suitable for sending over the wire.
The unique ID of a function registered with ProgressThread. Composed of the ProgressThread address an...