14 #include <unordered_map>
17 #include <rapidsmpf/communicator/communicator.hpp>
18 #include <rapidsmpf/error.hpp>
19 #include <rapidsmpf/memory/buffer_resource.hpp>
20 #include <rapidsmpf/memory/packed_data.hpp>
21 #include <rapidsmpf/nvtx.hpp>
22 #include <rapidsmpf/progress_thread.hpp>
23 #include <rapidsmpf/shuffler/chunk.hpp>
24 #include <rapidsmpf/shuffler/finish_counter.hpp>
25 #include <rapidsmpf/shuffler/postbox.hpp>
26 #include <rapidsmpf/statistics.hpp>
27 #include <rapidsmpf/utils/misc.hpp>
62 std::shared_ptr<Communicator>
const&
comm,
66 return safe_cast<Rank>(pid % safe_cast<PartID>(
comm->nranks()));
83 return safe_cast<Rank>(
97 std::shared_ptr<Communicator>
const&
comm,
121 std::shared_ptr<Communicator>
comm,
144 std::shared_ptr<Communicator>
comm,
159 [[nodiscard]] std::shared_ptr<Communicator>
const&
comm() const noexcept {
178 void insert(std::unordered_map<PartID, PackedData>&& chunks);
244 void wait_on(
PartID pid, std::optional<std::chrono::milliseconds> timeout = {});
259 std::size_t
spill(std::optional<std::size_t> amount = std::nullopt);
265 [[nodiscard]] std::string
str()
const;
346 std::atomic<bool> active_{
true};
352 std::shared_ptr<Communicator> comm_;
358 std::vector<PartID>
const local_partitions_;
361 std::unordered_map<PartID, detail::ChunkID> outbound_chunk_counter_;
362 mutable std::mutex outbound_chunk_counter_mutex_;
366 mutable std::mutex ready_postbox_spilling_mutex_;
368 std::atomic<detail::ChunkID> chunk_id_counter_{0};
370 std::shared_ptr<Statistics> statistics_;
Class managing buffer resources.
std::size_t SpillFunctionID
Represents a unique identifier for a registered spill function.
Shuffle service for cuDF tables.
Shuffler(std::shared_ptr< Communicator > comm, OpID op_id, PartID total_num_partitions, BufferResource *br, FinishedCallback &&finished_callback, PartitionOwner partition_owner=round_robin)
Construct a new shuffler for a single shuffle.
void insert_finished(PartID pid)
Insert a finish mark for a partition.
static constexpr std::pair< Rank, std::uint64_t > extract_info(detail::ChunkID cid)
Extract the rank and counter from a chunk ID.
PartID wait_any(std::optional< std::chrono::milliseconds > timeout={})
Wait for any partition to finish.
void shutdown()
Shutdown the shuffle, blocking until all inflight communication is done.
static constexpr std::uint64_t extract_counter(detail::ChunkID cid)
Extract the counter from a chunk ID.
void insert(std::unordered_map< PartID, PackedData > &&chunks)
Insert a bunch of packed (serialized) chunks into the shuffle.
std::size_t spill(std::optional< std::size_t > amount=std::nullopt)
Spills data out of device memory if necessary.
static constexpr int chunk_id_counter_bits
The number of bits used to store the counter in a chunk ID.
Shuffler(std::shared_ptr< Communicator > comm, OpID op_id, PartID total_num_partitions, BufferResource *br, PartitionOwner partition_owner=round_robin)
Construct a new shuffler for a single shuffle.
static std::vector< PartID > local_partitions(std::shared_ptr< Communicator > const &comm, PartID total_num_partitions, PartitionOwner partition_owner)
Returns the local partition IDs owned by the current node.
static Rank contiguous(std::shared_ptr< Communicator > const &comm, PartID pid, PartID total_num_partitions)
A PartitionOwner that assigns contiguous partition ID ranges to ranks.
std::function< Rank(std::shared_ptr< Communicator > const &, PartID, PartID)> PartitionOwner
Function that given a Communicator, PartID, and total partition count, returns the rapidsmpf::Rank of...
void wait_on(PartID pid, std::optional< std::chrono::milliseconds > timeout={})
Wait for a specific partition to finish (blocking).
void insert_finished(std::vector< PartID > &&pids)
Insert a finish mark for a list of partitions.
PartID const total_num_partitions
Total number of partitions in the shuffle.
static Rank round_robin(std::shared_ptr< Communicator > const &comm, PartID pid, [[maybe_unused]] PartID total_num_partitions)
A PartitionOwner that distributes partitions using round robin assignment.
std::string str() const
Returns a description of this instance.
static constexpr Rank extract_rank(detail::ChunkID cid)
Extract the rank from a chunk ID.
bool finished() const
Check if all partitions are finished.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this Shuffler.
detail::FinishCounter::FinishedCallback FinishedCallback
Callback function type called when a partition is finished.
std::vector< PackedData > extract(PartID pid)
Extract all chunks belonging to the specified partition.
static constexpr std::uint64_t counter_mask
The mask for the counter in a chunk ID.
std::span< PartID const > local_partitions() const
Returns the local partition IDs owned by the shuffler.
bool is_finished(PartID pid) const
Check if a partition is finished.
PartitionOwner const partition_owner
Function to determine partition ownership.
A partition chunk representing either a data message or a control message.
Helper to tally the finish status of a shuffle.
std::function< void(PartID)> FinishedCallback
Callback function type called when a partition is finished.
std::uint64_t ChunkID
The globally unique ID of a chunk.
std::uint32_t PartID
Partition ID, which goes from 0 to the total number of partitions.
std::ostream & operator<<(std::ostream &os, detail::FinishCounter const &obj)
Overloads the stream insertion operator for the FinishCounter class.
std::int32_t Rank
The rank of a node (e.g. the rank of an MPI process), or world size (total number of ranks).
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations,...
Bag of bytes with metadata suitable for sending over the wire.
The unique ID of a function registered with ProgressThread. Composed of the ProgressThread address an...