shuffler.hpp
1 
5 #pragma once
6 
7 #include <atomic>
8 #include <chrono>
9 #include <functional>
10 #include <memory>
11 #include <mutex>
12 #include <optional>
13 #include <span>
14 #include <unordered_map>
15 #include <vector>
16 
17 #include <rapidsmpf/communicator/communicator.hpp>
18 #include <rapidsmpf/error.hpp>
19 #include <rapidsmpf/memory/buffer_resource.hpp>
20 #include <rapidsmpf/memory/packed_data.hpp>
21 #include <rapidsmpf/nvtx.hpp>
22 #include <rapidsmpf/progress_thread.hpp>
23 #include <rapidsmpf/shuffler/chunk.hpp>
24 #include <rapidsmpf/shuffler/finish_counter.hpp>
25 #include <rapidsmpf/shuffler/postbox.hpp>
26 #include <rapidsmpf/statistics.hpp>
27 #include <rapidsmpf/utils/misc.hpp>
28 
35 namespace rapidsmpf::shuffler {
36 
44 class Shuffler {
45  public:
51  std::function<Rank(std::shared_ptr<Communicator> const&, PartID, PartID)>;
52 
61  static Rank round_robin(
62  std::shared_ptr<Communicator> const& comm,
63  PartID pid,
64  [[maybe_unused]] PartID total_num_partitions
65  ) {
66  return safe_cast<Rank>(pid % safe_cast<PartID>(comm->nranks()));
67  }
68 
80  static Rank contiguous(
81  std::shared_ptr<Communicator> const& comm, PartID pid, PartID total_num_partitions
82  ) {
83  return safe_cast<Rank>(
84  (pid * safe_cast<PartID>(comm->nranks())) / total_num_partitions
85  );
86  }
87 
96  static std::vector<PartID> local_partitions(
97  std::shared_ptr<Communicator> const& comm,
100  );
101 
104 
121  std::shared_ptr<Communicator> comm,
122  OpID op_id,
124  BufferResource* br,
125  FinishedCallback&& finished_callback,
127  );
128 
144  std::shared_ptr<Communicator> comm,
145  OpID op_id,
147  BufferResource* br,
149  )
150  : Shuffler(comm, op_id, total_num_partitions, br, nullptr, partition_owner) {}
151 
152  ~Shuffler();
153 
159  [[nodiscard]] std::shared_ptr<Communicator> const& comm() const noexcept {
160  return comm_;
161  }
162 
163  Shuffler(Shuffler const&) = delete;
164  Shuffler& operator=(Shuffler const&) = delete;
165 
171  void shutdown();
172 
178  void insert(std::unordered_map<PartID, PackedData>&& chunks);
179 
188 
194  void insert_finished(std::vector<PartID>&& pids);
195 
208  [[nodiscard]] std::vector<PackedData> extract(PartID pid);
209 
215  [[nodiscard]] bool finished() const;
216 
223  [[nodiscard]] bool is_finished(PartID pid) const;
224 
234  PartID wait_any(std::optional<std::chrono::milliseconds> timeout = {});
235 
244  void wait_on(PartID pid, std::optional<std::chrono::milliseconds> timeout = {});
245 
259  std::size_t spill(std::optional<std::size_t> amount = std::nullopt);
260 
265  [[nodiscard]] std::string str() const;
266 
272  [[nodiscard]] std::span<PartID const> local_partitions() const;
273 
277  static constexpr int chunk_id_counter_bits = 38;
278 
282  static constexpr std::uint64_t counter_mask =
283  (std::uint64_t{1} << chunk_id_counter_bits) - 1;
284 
290  static constexpr std::uint64_t extract_counter(detail::ChunkID cid) {
291  return cid & counter_mask;
292  }
293 
299  static constexpr Rank extract_rank(detail::ChunkID cid) {
300  return safe_cast<Rank>(cid >> chunk_id_counter_bits);
301  }
302 
308  static constexpr std::pair<Rank, std::uint64_t> extract_info(detail::ChunkID cid) {
309  return std::make_pair(extract_rank(cid), extract_counter(cid));
310  }
311 
312  private:
318  void insert(detail::Chunk&& chunk);
319 
325  void insert_into_ready_postbox(detail::Chunk&& chunk);
326 
328  [[nodiscard]] detail::ChunkID get_new_cid();
329 
338  [[nodiscard]] detail::Chunk create_chunk(PartID pid, PackedData&& packed_data);
339 
340  public:
343 
344  private:
345  BufferResource* br_;
346  std::atomic<bool> active_{true};
347  detail::PostBox<Rank> outgoing_postbox_;
349  detail::PostBox<PartID> ready_postbox_;
351 
352  std::shared_ptr<Communicator> comm_;
353  ProgressThread::FunctionID progress_thread_function_id_;
354  OpID const op_id_;
355 
356  SpillManager::SpillFunctionID spill_function_id_;
357 
358  std::vector<PartID> const local_partitions_;
359 
360  detail::FinishCounter finish_counter_;
361  std::unordered_map<PartID, detail::ChunkID> outbound_chunk_counter_;
362  mutable std::mutex outbound_chunk_counter_mutex_;
363 
364  // We protect ready_postbox extraction to avoid returning a chunk that is in the
365  // process of being spilled by `Shuffler::spill`.
366  mutable std::mutex ready_postbox_spilling_mutex_;
367 
368  std::atomic<detail::ChunkID> chunk_id_counter_{0};
369 
370  std::shared_ptr<Statistics> statistics_;
371 
372  class Progress;
373 };
374 
384 inline std::ostream& operator<<(std::ostream& os, Shuffler const& obj) {
385  os << obj.str();
386  return os;
387 }
388 
389 } // namespace rapidsmpf::shuffler
Class managing buffer resources.
std::size_t SpillFunctionID
Represents a unique identifier for a registered spill function.
Shuffle service for cuDF tables.
Definition: shuffler.hpp:44
Shuffler(std::shared_ptr< Communicator > comm, OpID op_id, PartID total_num_partitions, BufferResource *br, FinishedCallback &&finished_callback, PartitionOwner partition_owner=round_robin)
Construct a new shuffler for a single shuffle.
void insert_finished(PartID pid)
Insert a finish mark for a partition.
static constexpr std::pair< Rank, std::uint64_t > extract_info(detail::ChunkID cid)
Extract the rank and counter from a chunk ID.
Definition: shuffler.hpp:308
PartID wait_any(std::optional< std::chrono::milliseconds > timeout={})
Wait for any partition to finish.
void shutdown()
Shut down the shuffle, blocking until all in-flight communication is done.
static constexpr std::uint64_t extract_counter(detail::ChunkID cid)
Extract the counter from a chunk ID.
Definition: shuffler.hpp:290
void insert(std::unordered_map< PartID, PackedData > &&chunks)
Insert a bunch of packed (serialized) chunks into the shuffle.
std::size_t spill(std::optional< std::size_t > amount=std::nullopt)
Spills data out of device memory if necessary.
static constexpr int chunk_id_counter_bits
The number of bits used to store the counter in a chunk ID.
Definition: shuffler.hpp:277
Shuffler(std::shared_ptr< Communicator > comm, OpID op_id, PartID total_num_partitions, BufferResource *br, PartitionOwner partition_owner=round_robin)
Construct a new shuffler for a single shuffle.
Definition: shuffler.hpp:143
static std::vector< PartID > local_partitions(std::shared_ptr< Communicator > const &comm, PartID total_num_partitions, PartitionOwner partition_owner)
Returns the local partition IDs owned by the current node.
static Rank contiguous(std::shared_ptr< Communicator > const &comm, PartID pid, PartID total_num_partitions)
A PartitionOwner that assigns contiguous partition ID ranges to ranks.
Definition: shuffler.hpp:80
std::function< Rank(std::shared_ptr< Communicator > const &, PartID, PartID)> PartitionOwner
Function that given a Communicator, PartID, and total partition count, returns the rapidsmpf::Rank of...
Definition: shuffler.hpp:51
void wait_on(PartID pid, std::optional< std::chrono::milliseconds > timeout={})
Wait for a specific partition to finish (blocking).
void insert_finished(std::vector< PartID > &&pids)
Insert a finish mark for a list of partitions.
PartID const total_num_partitions
Total number of partitions in the shuffle.
Definition: shuffler.hpp:341
static Rank round_robin(std::shared_ptr< Communicator > const &comm, PartID pid, [[maybe_unused]] PartID total_num_partitions)
A PartitionOwner that distributes partitions using round robin assignment.
Definition: shuffler.hpp:61
std::string str() const
Returns a description of this instance.
static constexpr Rank extract_rank(detail::ChunkID cid)
Extract the rank from a chunk ID.
Definition: shuffler.hpp:299
bool finished() const
Check if all partitions are finished.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this Shuffler.
Definition: shuffler.hpp:159
detail::FinishCounter::FinishedCallback FinishedCallback
Callback function type called when a partition is finished.
Definition: shuffler.hpp:103
std::vector< PackedData > extract(PartID pid)
Extract all chunks belonging to the specified partition.
static constexpr std::uint64_t counter_mask
The mask for the counter in a chunk ID.
Definition: shuffler.hpp:282
std::span< PartID const > local_partitions() const
Returns the local partition IDs owned by the shuffler.
bool is_finished(PartID pid) const
Check if a partition is finished.
PartitionOwner const partition_owner
Function to determine partition ownership.
Definition: shuffler.hpp:342
A partition chunk representing either a data message or a control message.
Definition: chunk.hpp:58
Helper to tally the finish status of a shuffle.
std::function< void(PartID)> FinishedCallback
Callback function type called when a partition is finished.
std::uint64_t ChunkID
The globally unique ID of a chunk.
Definition: chunk.hpp:29
Shuffler interfaces.
Definition: chunk.hpp:15
std::uint32_t PartID
Partition ID, which ranges from 0 to the total number of partitions minus one.
Definition: chunk.hpp:22
std::ostream & operator<<(std::ostream &os, detail::FinishCounter const &obj)
Overloads the stream insertion operator for the FinishCounter class.
std::int32_t Rank
The rank of a node (e.g. the rank of an MPI process), or world size (total number of ranks).
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations,...
Bag of bytes with metadata suitable for sending over the wire.
Definition: packed_data.hpp:26
The unique ID of a function registered with ProgressThread. Composed of the ProgressThread address an...