shuffler.hpp
1 
5 #pragma once
6 
7 #include <atomic>
8 #include <chrono>
9 #include <functional>
10 #include <memory>
11 #include <mutex>
12 #include <optional>
13 #include <span>
14 #include <unordered_map>
15 #include <vector>
16 
17 #include <rapidsmpf/buffer/packed_data.hpp>
18 #include <rapidsmpf/buffer/resource.hpp>
19 #include <rapidsmpf/communicator/communicator.hpp>
20 #include <rapidsmpf/error.hpp>
21 #include <rapidsmpf/nvtx.hpp>
22 #include <rapidsmpf/progress_thread.hpp>
23 #include <rapidsmpf/shuffler/chunk.hpp>
24 #include <rapidsmpf/shuffler/finish_counter.hpp>
25 #include <rapidsmpf/shuffler/postbox.hpp>
26 #include <rapidsmpf/statistics.hpp>
27 #include <rapidsmpf/utils.hpp>
28 
29 
30 class ShuffleInsertGroupedTest;
31 
38 namespace rapidsmpf::shuffler {
39 
47 class Shuffler {
48  friend class ::ShuffleInsertGroupedTest;
49 
50  public:
/// Function that, given a Communicator and a PartID, returns the rapidsmpf::Rank of
/// the owning node.
55  using PartitionOwner = std::function<Rank(std::shared_ptr<Communicator>, PartID)>;
56 
64  static Rank round_robin(std::shared_ptr<Communicator> const& comm, PartID pid) {
65  return static_cast<Rank>(pid % static_cast<PartID>(comm->nranks()));
66  }
67 
76  static std::vector<PartID> local_partitions(
77  std::shared_ptr<Communicator> const& comm,
80  );
81 
84 
103  std::shared_ptr<Communicator> comm,
104  std::shared_ptr<ProgressThread> progress_thread,
105  OpID op_id,
107  BufferResource* br,
108  FinishedCallback&& finished_callback,
109  std::shared_ptr<Statistics> statistics = Statistics::disabled(),
111  );
112 
130  std::shared_ptr<Communicator> comm,
131  std::shared_ptr<ProgressThread> progress_thread,
132  OpID op_id,
134  BufferResource* br,
135  std::shared_ptr<Statistics> statistics = Statistics::disabled(),
137  )
138  : Shuffler(
139  comm,
140  progress_thread,
141  op_id,
143  br,
144  nullptr,
145  statistics,
147  ) {}
148 
149  ~Shuffler();
150 
151  Shuffler(Shuffler const&) = delete;
152  Shuffler& operator=(Shuffler const&) = delete;
153 
/// @brief Shutdown the shuffle, blocking until all inflight communication is done.
159  void shutdown();
160 
/// @brief Insert a map of packed data, grouping them by destination rank, and
/// concatenating into a single chunk.
/// @param chunks Map from partition ID to packed data, taken by rvalue reference.
167  void concat_insert(std::unordered_map<PartID, PackedData>&& chunks);
168 
/// @brief Insert a bunch of packed (serialized) chunks into the shuffle.
/// @param chunks Map from partition ID to packed chunk, taken by rvalue reference.
174  void insert(std::unordered_map<PartID, PackedData>&& chunks);
175 
// NOTE(review): the generated summaries also list `insert_finished(PartID pid)`
// ("Insert a finish mark for a partition."); its declaration is not visible in
// this listing — confirm against the real header.
184 
/// @brief Insert a finish mark for a list of partitions.
/// @param pids The partition IDs to mark, taken by rvalue reference.
190  void insert_finished(std::vector<PartID>&& pids);
191 
/// @brief Extract all chunks belonging to the specified partition.
/// @param pid The partition ID.
/// @return The packed data of the partition's chunks (must not be discarded).
204  [[nodiscard]] std::vector<PackedData> extract(PartID pid);
205 
/// @brief Check if all partitions are finished.
211  [[nodiscard]] bool finished() const;
212 
/// @brief Check if a partition is finished.
/// @param pid The partition ID to check.
219  [[nodiscard]] bool is_finished(PartID pid) const;
220 
/// @brief Wait for any partition to finish.
/// @param timeout Optional timeout in milliseconds; defaults to an empty optional.
/// @return A partition ID — presumably the partition that finished; TODO confirm.
230  PartID wait_any(std::optional<std::chrono::milliseconds> timeout = {});
231 
/// @brief Wait for a specific partition to finish (blocking).
/// @param pid The partition ID to wait on.
/// @param timeout Optional timeout in milliseconds; defaults to an empty optional.
240  void wait_on(PartID pid, std::optional<std::chrono::milliseconds> timeout = {});
241 
/// @brief Spills data to device if necessary.
/// NOTE(review): the generated summary says "to device"; confirm the intended
/// spill direction against the implementation.
/// @param amount Optional amount to spill; defaults to `std::nullopt`.
/// @return A size — presumably the amount actually spilled; TODO confirm.
258  std::size_t spill(std::optional<std::size_t> amount = std::nullopt);
259 
/// @brief Returns a description of this instance.
264  [[nodiscard]] std::string str() const;
265 
/// @brief Returns the local partition IDs owned by the shuffler.
/// @return A span over constant partition IDs.
271  [[nodiscard]] std::span<PartID const> local_partitions() const;
272 
/// The number of bits used to store the counter in a chunk ID.
276  static constexpr int chunk_id_counter_bits = 38;
277 
/// The mask for the counter in a chunk ID: the low `chunk_id_counter_bits` bits are set.
281  static constexpr uint64_t counter_mask = (uint64_t(1) << chunk_id_counter_bits) - 1;
282 
288  static constexpr uint64_t extract_counter(detail::ChunkID cid) {
289  return cid & counter_mask;
290  }
291 
297  static constexpr Rank extract_rank(detail::ChunkID cid) {
298  return static_cast<Rank>(cid >> chunk_id_counter_bits);
299  }
300 
306  static constexpr std::pair<Rank, uint64_t> extract_info(detail::ChunkID cid) {
307  return std::make_pair(extract_rank(cid), extract_counter(cid));
308  }
309 
310  private:
316  void insert(detail::Chunk&& chunk);
317 
323  void insert_into_ready_postbox(detail::Chunk&& chunk);
324 
326  [[nodiscard]] detail::ChunkID get_new_cid();
327 
336  [[nodiscard]] detail::Chunk create_chunk(PartID pid, PackedData&& packed_data);
337 
338  public:
341 
342  private:
343  BufferResource* br_;
344  std::atomic<bool> active_{true};
345  detail::PostBox<Rank> outgoing_postbox_;
347  detail::PostBox<PartID> ready_postbox_;
349 
350  std::shared_ptr<Communicator> comm_;
351  std::shared_ptr<ProgressThread> progress_thread_;
352  ProgressThread::FunctionID progress_thread_function_id_;
353  OpID const op_id_;
354 
355  SpillManager::SpillFunctionID spill_function_id_;
356 
357  std::vector<PartID> const local_partitions_;
358 
359  detail::FinishCounter finish_counter_;
360  std::unordered_map<PartID, detail::ChunkID> outbound_chunk_counter_;
361  mutable std::mutex outbound_chunk_counter_mutex_;
362 
363  // We protect ready_postbox extraction to avoid returning a chunk that is in the
364  // process of being spilled by `Shuffler::spill`.
365  mutable std::mutex ready_postbox_spilling_mutex_;
366 
367  std::atomic<detail::ChunkID> chunk_id_counter_{0};
368 
369  std::shared_ptr<Statistics> statistics_;
370 
371  class Progress;
372 };
373 
383 inline std::ostream& operator<<(std::ostream& os, Shuffler const& obj) {
384  os << obj.str();
385  return os;
386 }
387 
388 } // namespace rapidsmpf::shuffler
Class managing buffer resources.
Definition: resource.hpp:133
std::size_t SpillFunctionID
Represents a unique identifier for a registered spill function.
static std::shared_ptr< Statistics > disabled()
Returns a shared pointer to a disabled (no-op) Statistics instance.
Shuffle service for cuDF tables.
Definition: shuffler.hpp:47
static constexpr uint64_t counter_mask
The mask for the counter in a chunk ID.
Definition: shuffler.hpp:281
void insert_finished(PartID pid)
Insert a finish mark for a partition.
Shuffler(std::shared_ptr< Communicator > comm, std::shared_ptr< ProgressThread > progress_thread, OpID op_id, PartID total_num_partitions, BufferResource *br, FinishedCallback &&finished_callback, std::shared_ptr< Statistics > statistics=Statistics::disabled(), PartitionOwner partition_owner=round_robin)
Construct a new shuffler for a single shuffle.
PartID wait_any(std::optional< std::chrono::milliseconds > timeout={})
Wait for any partition to finish.
void shutdown()
Shutdown the shuffle, blocking until all inflight communication is done.
void insert(std::unordered_map< PartID, PackedData > &&chunks)
Insert a bunch of packed (serialized) chunks into the shuffle.
std::size_t spill(std::optional< std::size_t > amount=std::nullopt)
Spills data to device if necessary.
static constexpr int chunk_id_counter_bits
The number of bits used to store the counter in a chunk ID.
Definition: shuffler.hpp:276
Shuffler(std::shared_ptr< Communicator > comm, std::shared_ptr< ProgressThread > progress_thread, OpID op_id, PartID total_num_partitions, BufferResource *br, std::shared_ptr< Statistics > statistics=Statistics::disabled(), PartitionOwner partition_owner=round_robin)
Construct a new shuffler for a single shuffle.
Definition: shuffler.hpp:129
static std::vector< PartID > local_partitions(std::shared_ptr< Communicator > const &comm, PartID total_num_partitions, PartitionOwner partition_owner)
Returns the local partition IDs owned by the current node.
void concat_insert(std::unordered_map< PartID, PackedData > &&chunks)
Insert a map of packed data, grouping them by destination rank, and concatenating into a single chunk...
void wait_on(PartID pid, std::optional< std::chrono::milliseconds > timeout={})
Wait for a specific partition to finish (blocking).
static Rank round_robin(std::shared_ptr< Communicator > const &comm, PartID pid)
A PartitionOwner that distributes the partitions using round robin.
Definition: shuffler.hpp:64
void insert_finished(std::vector< PartID > &&pids)
Insert a finish mark for a list of partitions.
PartID const total_num_partitions
Total number of partitions in the shuffle.
Definition: shuffler.hpp:339
std::string str() const
Returns a description of this instance.
static constexpr std::pair< Rank, uint64_t > extract_info(detail::ChunkID cid)
Extract the rank and counter from a chunk ID.
Definition: shuffler.hpp:306
std::function< Rank(std::shared_ptr< Communicator >, PartID)> PartitionOwner
Function that given a Communicator and a PartID, returns the rapidsmpf::Rank of the owning node.
Definition: shuffler.hpp:55
static constexpr Rank extract_rank(detail::ChunkID cid)
Extract the rank from a chunk ID.
Definition: shuffler.hpp:297
bool finished() const
Check if all partitions are finished.
detail::FinishCounter::FinishedCallback FinishedCallback
Callback function type called when a partition is finished.
Definition: shuffler.hpp:83
static constexpr uint64_t extract_counter(detail::ChunkID cid)
Extract the counter from a chunk ID.
Definition: shuffler.hpp:288
std::vector< PackedData > extract(PartID pid)
Extract all chunks belonging to the specified partition.
std::span< PartID const > local_partitions() const
Returns the local partition IDs owned by the shuffler.
bool is_finished(PartID pid) const
Check if a partition is finished.
PartitionOwner const partition_owner
Function to determine partition ownership.
Definition: shuffler.hpp:340
Chunk with multiple messages. This class contains two buffers for concatenated metadata and data.
Definition: chunk.hpp:58
Helper to tally the finish status of a shuffle.
std::function< void(PartID)> FinishedCallback
Callback function type called when a partition is finished.
std::uint64_t ChunkID
The globally unique ID of a chunk.
Definition: chunk.hpp:29
Shuffler interfaces.
Definition: chunk.hpp:15
std::uint32_t PartID
Partition ID, which goes from 0 to the total number of partitions.
Definition: chunk.hpp:22
std::ostream & operator<<(std::ostream &os, detail::FinishCounter const &obj)
Overloads the stream insertion operator for the FinishCounter class.
Bag of bytes with metadata suitable for sending over the wire.
Definition: packed_data.hpp:26
The unique ID of a function registered with ProgressThread. Composed of the ProgressThread address an...