shuffler.hpp
1 
5 #pragma once
6 
7 #include <atomic>
8 #include <chrono>
#include <cstdint>
9 #include <condition_variable>
10 #include <functional>
11 #include <memory>
12 #include <mutex>
13 #include <optional>
14 #include <span>
15 #include <unordered_map>
16 #include <vector>
17 
18 #include <rapidsmpf/communicator/communicator.hpp>
19 #include <rapidsmpf/communicator/metadata_payload_exchange/tag.hpp>
20 #include <rapidsmpf/error.hpp>
21 #include <rapidsmpf/memory/buffer_resource.hpp>
22 #include <rapidsmpf/memory/packed_data.hpp>
23 #include <rapidsmpf/nvtx.hpp>
24 #include <rapidsmpf/progress_thread.hpp>
25 #include <rapidsmpf/shuffler/chunk.hpp>
26 #include <rapidsmpf/shuffler/finish_counter.hpp>
27 #include <rapidsmpf/shuffler/postbox.hpp>
28 #include <rapidsmpf/statistics.hpp>
29 #include <rapidsmpf/utils/misc.hpp>
30 
37 namespace rapidsmpf::shuffler {
38 
46 class Shuffler {
47  public:
53  std::function<Rank(std::shared_ptr<Communicator> const&, PartID, PartID)>;
54 
63  static Rank round_robin(
64  std::shared_ptr<Communicator> const& comm,
65  PartID pid,
66  [[maybe_unused]] PartID total_num_partitions
67  ) {
68  return safe_cast<Rank>(pid % safe_cast<PartID>(comm->nranks()));
69  }
70 
82  static Rank contiguous(
83  std::shared_ptr<Communicator> const& comm, PartID pid, PartID total_num_partitions
84  ) {
85  return safe_cast<Rank>(
86  (pid * safe_cast<PartID>(comm->nranks())) / total_num_partitions
87  );
88  }
89 
98  static std::vector<PartID> local_partitions(
99  std::shared_ptr<Communicator> const& comm,
102  );
103 
111  using FinishedCallback = std::function<void()>;
112 
133  std::shared_ptr<Communicator> comm,
134  OpID op_id,
136  BufferResource* br,
137  FinishedCallback&& finished_callback,
139  std::unique_ptr<communicator::MetadataPayloadExchange> mpe = nullptr
140  );
141 
159  std::shared_ptr<Communicator> comm,
160  OpID op_id,
162  BufferResource* br,
164  std::unique_ptr<communicator::MetadataPayloadExchange> mpe = nullptr
165  )
166  : Shuffler(
167  comm,
168  op_id,
170  br,
171  nullptr,
173  std::move(mpe)
174  ) {}
175 
176  ~Shuffler();
177 
183  [[nodiscard]] std::shared_ptr<Communicator> const& comm() const noexcept {
184  return comm_;
185  }
186 
187  Shuffler(Shuffler const&) = delete;
188  Shuffler& operator=(Shuffler const&) = delete;
189 
195  void shutdown();
196 
205  void insert(std::unordered_map<PartID, PackedData>&& chunks);
206 
218 
231  [[nodiscard]] std::vector<PackedData> extract(PartID pid);
232 
238  [[nodiscard]] bool finished() const;
239 
247  void wait(std::optional<std::chrono::milliseconds> timeout = {});
248 
262  std::size_t spill(std::optional<std::size_t> amount = std::nullopt);
263 
268  [[nodiscard]] std::string str() const;
269 
275  [[nodiscard]] std::span<PartID const> local_partitions() const;
276 
280  static constexpr int chunk_id_counter_bits = 38;
281 
287  static constexpr Rank extract_rank(detail::ChunkID cid) {
288  return safe_cast<Rank>(cid >> chunk_id_counter_bits);
289  }
290 
291  private:
297  void insert(detail::Chunk&& chunk);
298 
304  void insert_into_received(detail::Chunk&& chunk);
305 
307  [[nodiscard]] detail::ChunkID get_new_cid();
308 
317  [[nodiscard]] detail::Chunk create_chunk(PartID pid, PackedData&& packed_data);
318 
319  public:
322 
323  private:
324  BufferResource* br_;
325  std::atomic<bool> active_{true};
326  // Have we called `insert_finished()` on this rank.
327  std::atomic<bool> locally_finished_{false};
328  // Flipped to true exactly once when partitions are ready for extraction and we've
329  // posted all sends we're going to
330  bool can_extract_{false};
331  detail::ChunksToSend to_send_;
332  detail::ReceivedChunks received_;
334 
335  std::shared_ptr<Communicator> comm_;
336  std::unique_ptr<communicator::MetadataPayloadExchange> mpe_;
337  ProgressThread::FunctionID progress_thread_function_id_;
338 
339  SpillManager::SpillFunctionID spill_function_id_;
340 
341  std::vector<PartID> const local_partitions_;
342 
343  detail::FinishCounter finish_counter_;
344  std::vector<detail::ChunkID> outbound_chunk_counter_;
345  std::atomic<detail::ChunkID> chunk_id_counter_{0};
346 
347  std::shared_ptr<Statistics> statistics_;
348 
349  // For notifications.
350  mutable std::mutex mutex_;
351  std::condition_variable cv_;
352  FinishedCallback finished_callback_;
353 
354  class Progress;
355 };
356 
366 inline std::ostream& operator<<(std::ostream& os, Shuffler const& obj) {
367  os << obj.str();
368  return os;
369 }
370 
371 } // namespace rapidsmpf::shuffler
Class managing buffer resources.
std::size_t SpillFunctionID
Represents a unique identifier for a registered spill function.
Shuffle service for all-to-all style communication of partitioned data.
Definition: shuffler.hpp:46
Shuffler(std::shared_ptr< Communicator > comm, OpID op_id, PartID total_num_partitions, BufferResource *br, FinishedCallback &&finished_callback, PartitionOwner partition_owner=round_robin, std::unique_ptr< communicator::MetadataPayloadExchange > mpe=nullptr)
Construct a new shuffler for a single shuffle.
void shutdown()
Shut down the shuffle, blocking until all inflight communication is done.
void insert(std::unordered_map< PartID, PackedData > &&chunks)
Insert a bunch of packed (serialized) chunks into the shuffle.
std::size_t spill(std::optional< std::size_t > amount=std::nullopt)
Spills device data (e.g. to host memory) if necessary.
static constexpr int chunk_id_counter_bits
The number of bits used to store the counter in a chunk ID.
Definition: shuffler.hpp:280
static std::vector< PartID > local_partitions(std::shared_ptr< Communicator > const &comm, PartID total_num_partitions, PartitionOwner partition_owner)
Returns the local partition IDs owned by the current node.
static Rank contiguous(std::shared_ptr< Communicator > const &comm, PartID pid, PartID total_num_partitions)
A PartitionOwner that assigns contiguous partition ID ranges to ranks.
Definition: shuffler.hpp:82
std::function< Rank(std::shared_ptr< Communicator > const &, PartID, PartID)> PartitionOwner
Function that, given a Communicator, a PartID, and the total partition count, returns the rapidsmpf::Rank that owns the partition.
Definition: shuffler.hpp:53
void insert_finished()
Signal that no more data will be inserted into the shuffle.
PartID const total_num_partitions
Total number of partitions in the shuffle.
Definition: shuffler.hpp:320
static Rank round_robin(std::shared_ptr< Communicator > const &comm, PartID pid, [[maybe_unused]] PartID total_num_partitions)
A PartitionOwner that distributes partitions using round robin assignment.
Definition: shuffler.hpp:63
std::string str() const
Returns a description of this instance.
static constexpr Rank extract_rank(detail::ChunkID cid)
Extract the rank from a chunk ID.
Definition: shuffler.hpp:287
bool finished() const
Check if all partitions are finished.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this Shuffler.
Definition: shuffler.hpp:183
Shuffler(std::shared_ptr< Communicator > comm, OpID op_id, PartID total_num_partitions, BufferResource *br, PartitionOwner partition_owner=round_robin, std::unique_ptr< communicator::MetadataPayloadExchange > mpe=nullptr)
Construct a new shuffler for a single shuffle.
Definition: shuffler.hpp:158
std::vector< PackedData > extract(PartID pid)
Extract all chunks belonging to the specified partition.
std::span< PartID const > local_partitions() const
Returns the local partition IDs owned by the shuffler.
void wait(std::optional< std::chrono::milliseconds > timeout={})
Wait for all partitions to finish (blocking).
PartitionOwner const partition_owner
Function to determine partition ownership.
Definition: shuffler.hpp:321
std::function< void()> FinishedCallback
Callback function type called when all partitions are finished and data can be extracted.
Definition: shuffler.hpp:111
A partition chunk representing either a data message or a control message.
Definition: chunk.hpp:58
std::uint64_t ChunkID
The globally unique ID of a chunk.
Definition: chunk.hpp:29
Shuffler interfaces.
Definition: chunk.hpp:15
std::uint32_t PartID
Partition ID, ranging from 0 up to (but not including) the total number of partitions.
Definition: chunk.hpp:22
std::ostream & operator<<(std::ostream &os, detail::FinishCounter const &obj)
Overloads the stream insertion operator for the FinishCounter class.
std::int32_t Rank
The rank of a node (e.g. the rank of an MPI process), or world size (total number of ranks).
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations,...
Bag of bytes with metadata suitable for sending over the wire.
Definition: packed_data.hpp:26