allgather.hpp
1 
5 #pragma once
6 
7 #include <atomic>
8 #include <chrono>
9 #include <condition_variable>
10 #include <cstdint>
11 #include <functional>
12 #include <limits>
13 #include <memory>
14 #include <mutex>
15 #include <optional>
16 #include <vector>
17 
18 #include <rmm/cuda_stream_view.hpp>
19 
20 #include <rapidsmpf/buffer/buffer.hpp>
21 #include <rapidsmpf/buffer/packed_data.hpp>
22 #include <rapidsmpf/buffer/resource.hpp>
23 #include <rapidsmpf/buffer/spill_manager.hpp>
24 #include <rapidsmpf/communicator/communicator.hpp>
25 #include <rapidsmpf/error.hpp>
26 #include <rapidsmpf/progress_thread.hpp>
27 #include <rapidsmpf/statistics.hpp>
28 
36 namespace rapidsmpf::allgather {
37 namespace detail {
38 
/// Identifier type for a chunk. Chunk::chunk_id() builds a ChunkID from a
/// sequence number and an origin rank.
40 using ChunkID = std::uint64_t;
41 
/// Represents a data chunk in the allgather operation.
///
/// A chunk carries an identifier, a metadata byte vector and a data Buffer.
/// Instances are created through the static factories (from_packed_data,
/// from_empty, deserialize); the constructors are private. The type is
/// move-only.
51 class Chunk {
52  private:
    // Identifier of this chunk (see chunk_id()).
53  ChunkID id_;
    // Metadata bytes owned by this chunk.
54  std::unique_ptr<std::vector<std::uint8_t>> metadata_;
    // Data buffer owned by this chunk; it can be detached with
    // release_data_buffer() and re-attached with attach_data_buffer().
55  std::unique_ptr<Buffer> data_;
    // NOTE(review): presumably the size in bytes reported by data_size(),
    // stored separately from data_ -- confirm against the implementation.
56  std::uint64_t
57  data_size_;
59 
    // Private constructor taking an id, metadata and a data buffer.
    // NOTE(review): presumably used by from_packed_data()/deserialize() -- confirm.
67  Chunk(
68  ChunkID id,
69  std::unique_ptr<std::vector<std::uint8_t>> metadata,
70  std::unique_ptr<Buffer> data
71  );
72 
    // Private constructor taking only an id.
    // NOTE(review): presumably builds a finish marker (see from_empty()) -- confirm.
81  Chunk(ChunkID id);
82 
83  public:
    /// Check if the chunk is ready for processing.
92  [[nodiscard]] bool is_ready() const noexcept;
93 
    /// Return the memory type of the chunk.
100  [[nodiscard]] MemoryType memory_type() const noexcept;
101 
    /// Check if this is a finish marker chunk.
107  [[nodiscard]] bool is_finish() const noexcept;
108 
    /// The identifier of the chunk.
114  [[nodiscard]] ChunkID id() const noexcept;
115 
    /// The sequence number of the chunk.
121  [[nodiscard]] ChunkID sequence() const noexcept;
122 
    /// The origin rank of the chunk.
128  [[nodiscard]] Rank origin() const noexcept;
129 
    /// The size of the data buffer in bytes.
135  [[nodiscard]] std::uint64_t data_size() const noexcept;
136 
    /// The size of the metadata buffer in bytes.
142  [[nodiscard]] std::uint64_t metadata_size() const noexcept;
143 
    /// Create a data chunk from packed data.
    ///
    /// @param sequence    Sequence number of the chunk.
    /// @param origin      Origin rank of the chunk.
    /// @param packed_data Packed data consumed (moved from) to build the chunk.
    /// @return The newly created chunk.
152  [[nodiscard]] static std::unique_ptr<Chunk> from_packed_data(
153  std::uint64_t sequence, Rank origin, PackedData&& packed_data
154  );
155 
    /// Create an empty finish marker chunk.
    ///
    /// @param num_local_insertions NOTE(review): presumably the number of
    ///        chunks the origin rank inserted -- confirm against the caller.
    /// @param origin               Origin rank of the marker.
    /// @return The newly created finish marker chunk.
164  [[nodiscard]] static std::unique_ptr<Chunk> from_empty(
165  std::uint64_t num_local_insertions, Rank origin
166  );
167 
    /// Release the chunk's data as PackedData.
178  [[nodiscard]] PackedData release();
179 
    /// Number of bits used for the sequence ID in the chunk identifier.
181  static constexpr std::uint64_t ID_BITS = 38;
    /// Number of bits used for the rank in the chunk identifier: the bits of
    /// ChunkID that remain after ID_BITS.
183  static constexpr std::uint64_t RANK_BITS =
184  sizeof(ChunkID) * std::numeric_limits<unsigned char>::digits - ID_BITS;
185 
    /// Create a ChunkID from a sequence number and origin rank.
194  static constexpr ChunkID chunk_id(std::uint64_t sequence, Rank origin);
195 
    /// Serialize the metadata of the chunk to a byte vector.
201  [[nodiscard]] std::unique_ptr<std::vector<std::uint8_t>> serialize() const;
202 
    /// Deserialize a chunk from a byte vector of its metadata.
    ///
    /// @param data Serialized bytes (taken by non-const reference).
    /// @param br   Buffer resource. NOTE(review): presumably used to allocate
    ///             the chunk's data buffer -- confirm.
    /// @return The deserialized chunk.
213  [[nodiscard]] static std::unique_ptr<Chunk> deserialize(
214  std::vector<std::uint8_t>& data, BufferResource* br
215  );
216 
    /// Release and return the data buffer.
222  [[nodiscard]] std::unique_ptr<Buffer> release_data_buffer() noexcept;
223 
    /// Attach a data buffer to this chunk.
232  void attach_data_buffer(std::unique_ptr<Buffer> data);
233 
    /// Default destructor.
235  ~Chunk() = default;
    /// Default move constructor.
237  Chunk(Chunk&&) = default;
    /// Default move assignment operator.
240  Chunk& operator=(Chunk&&) = default;
    /// Deleted copy constructor.
242  Chunk(Chunk const&) = delete;
    /// Deleted copy assignment operator.
244  Chunk& operator=(Chunk const&) = delete;
245 };
246 
/// A thread-safe container for managing chunks in an AllGather.
///
/// Chunks are stored in a vector alongside a mutex, and an atomic "goalpost"
/// counter holds the expected chunk count. The type is neither copyable nor
/// movable.
255 class PostBox {
256  public:
    /// Default constructor.
258  PostBox() = default;
    /// Default destructor.
260  ~PostBox() = default;
    /// Deleted copy constructor.
262  PostBox(PostBox const&) = delete;
    /// Deleted copy assignment operator.
264  PostBox& operator=(PostBox const&) = delete;
    /// Deleted move constructor.
266  PostBox(PostBox&&) = delete;
    /// Deleted move assignment operator.
268  PostBox& operator=(PostBox&&) = delete;
269 
    /// Insert a single chunk into the postbox.
275  void insert(std::unique_ptr<Chunk> chunk);
276 
    /// Insert multiple chunks into the postbox.
282  void insert(std::vector<std::unique_ptr<Chunk>>&& chunks);
283 
    /// Increment the goalpost to a new expected chunk count.
    ///
    /// @param amount Value added to the goalpost.
289  void increment_goalpost(std::uint64_t amount);
290 
    /// Check if the postbox has reached its goal.
297  [[nodiscard]] bool ready() const noexcept;
298 
    /// Extract chunks from the postbox.
    /// NOTE(review): presumably only chunks for which Chunk::is_ready() holds
    /// are returned -- confirm against the implementation.
307  [[nodiscard]] std::vector<std::unique_ptr<Chunk>> extract_ready();
308 
    /// Extract chunks from the postbox.
    /// NOTE(review): presumably returns every stored chunk regardless of
    /// readiness -- confirm against the implementation.
317  [[nodiscard]] std::vector<std::unique_ptr<Chunk>> extract();
318 
    /// Check whether the postbox is empty.
324  [[nodiscard]] bool empty() const noexcept;
325 
    /// Spill chunks held by the postbox.
    ///
    /// @param br     Buffer resource.
    /// @param log    Communicator logger.
    /// @param amount NOTE(review): presumably the number of bytes requested
    ///               to be spilled -- confirm.
    /// @return NOTE(review): presumably the number of bytes actually
    ///         spilled -- confirm.
339  [[nodiscard]] std::size_t spill(
340  BufferResource* br, Communicator::Logger& log, std::size_t amount
341  );
342 
343  private:
    // NOTE(review): presumably guards chunks_ -- confirm. Declared mutable so
    // that const members can lock it.
344  mutable std::mutex mutex_{};
    // Chunks currently held by the postbox.
345  std::vector<std::unique_ptr<Chunk>> chunks_{};
    // Expected chunk count; raised by increment_goalpost().
346  std::atomic<std::uint64_t> goalpost_{0};
347 };
348 
349 } // namespace detail
350 
372 class AllGather {
373  public:
380  void insert(std::uint64_t sequence_number, PackedData&& packed_data);
381 
386 
393  [[nodiscard]] bool finished() const noexcept;
394 
396  enum class Ordered : bool {
397  NO,
398  YES,
399  };
414  [[nodiscard]] std::vector<PackedData> wait_and_extract(
415  Ordered ordered = Ordered::YES,
416  std::chrono::milliseconds timeout = std::chrono::milliseconds{-1}
417  );
418 
439  [[nodiscard]] std::vector<PackedData> extract_ready();
440 
459  std::shared_ptr<Communicator> comm,
460  std::shared_ptr<ProgressThread> progress_thread,
461  OpID op_id,
462  BufferResource* br,
463  std::shared_ptr<Statistics> statistics = Statistics::disabled(),
464  std::function<void(void)>&& finished_callback = nullptr
465  );
466 
468  AllGather(AllGather const&) = delete;
470  AllGather& operator=(AllGather const&) = delete;
472  AllGather(AllGather&&) = delete;
483 
493 
494  private:
500  void insert(std::unique_ptr<detail::Chunk> chunk);
501 
508  void mark_finish(std::uint64_t expected_chunks) noexcept;
509 
517  void wait(std::chrono::milliseconds timeout = std::chrono::milliseconds{-1});
518 
526  std::size_t spill(std::optional<std::size_t> amount = std::nullopt);
527 
528  std::shared_ptr<Communicator> comm_;
529  std::shared_ptr<ProgressThread>
530  progress_thread_;
531  BufferResource* br_;
532  std::shared_ptr<Statistics> statistics_;
533  std::function<void(void)> finished_callback_{
534  nullptr
535  };
536  std::atomic<Rank> finish_counter_;
537  std::atomic<std::uint32_t> nlocal_insertions_;
538  OpID op_id_;
539  std::atomic<bool> locally_finished_{false};
540  std::atomic<bool> active_{true};
541  bool can_extract_{false};
542  mutable std::mutex mutex_;
543  std::condition_variable cv_;
544  detail::PostBox inserted_{};
545  detail::PostBox for_extraction_{};
546  ProgressThread::FunctionID function_id_{};
547  SpillManager::SpillFunctionID spill_function_id_{};
549  std::vector<std::unique_ptr<detail::Chunk>> to_receive_{};
551  std::vector<std::unique_ptr<Communicator::Future>> fire_and_forget_{};
553  std::vector<std::unique_ptr<detail::Chunk>> sent_posted_{};
555  std::vector<std::unique_ptr<Communicator::Future>> sent_futures_{};
557  std::vector<std::unique_ptr<detail::Chunk>> receive_posted_{};
559  std::vector<std::unique_ptr<Communicator::Future>> receive_futures_{};
560 };
561 
562 } // namespace rapidsmpf::allgather
Class managing buffer resources.
Definition: resource.hpp:133
Buffer representing device or host memory.
Definition: buffer.hpp:53
Abstract base class for a communication mechanism between nodes.
ProgressState
The progress state of a function; it is either InProgress or Done.
std::size_t SpillFunctionID
Represents a unique identifier for a registered spill function.
static std::shared_ptr< Statistics > disabled()
Returns a shared pointer to a disabled (no-op) Statistics instance.
AllGather communication service.
Definition: allgather.hpp:372
ProgressThread::ProgressState event_loop()
Main event loop for processing allgather operations.
void insert(std::uint64_t sequence_number, PackedData &&packed_data)
Insert packed data into the allgather operation.
std::vector< PackedData > wait_and_extract(Ordered ordered=Ordered::YES, std::chrono::milliseconds timeout=std::chrono::milliseconds{-1})
Wait for completion and extract all gathered data.
AllGather(std::shared_ptr< Communicator > comm, std::shared_ptr< ProgressThread > progress_thread, OpID op_id, BufferResource *br, std::shared_ptr< Statistics > statistics=Statistics::disabled(), std::function< void(void)> &&finished_callback=nullptr)
Construct a new allgather operation.
Ordered
Tag requesting ordering for extraction.
Definition: allgather.hpp:396
void insert_finished()
Mark that this rank has finished contributing data.
AllGather(AllGather const &)=delete
Deleted copy constructor.
std::vector< PackedData > extract_ready()
Extract any available partitions.
AllGather & operator=(AllGather const &)=delete
Deleted copy assignment operator.
AllGather & operator=(AllGather &&)=delete
Deleted move assignment operator.
AllGather(AllGather &&)=delete
Deleted move constructor.
bool finished() const noexcept
Check if the allgather operation has completed.
Represents a data chunk in the allgather operation.
Definition: allgather.hpp:51
PackedData release()
Release the chunk's data as PackedData.
bool is_finish() const noexcept
Check if this is a finish marker chunk.
static std::unique_ptr< Chunk > deserialize(std::vector< std::uint8_t > &data, BufferResource *br)
Deserialize a chunk from a byte vector of its metadata.
Rank origin() const noexcept
The origin rank of the chunk.
static constexpr std::uint64_t ID_BITS
Number of bits used for the sequence ID in the chunk identifier.
Definition: allgather.hpp:181
ChunkID sequence() const noexcept
The sequence number of the chunk.
void attach_data_buffer(std::unique_ptr< Buffer > data)
Attach a data buffer to this chunk.
std::unique_ptr< std::vector< std::uint8_t > > serialize() const
Serialize the metadata of the chunk to a byte vector.
static std::unique_ptr< Chunk > from_empty(std::uint64_t num_local_insertions, Rank origin)
Create an empty finish marker chunk.
std::uint64_t data_size() const noexcept
The size of the data buffer in bytes.
MemoryType memory_type() const noexcept
Return the memory type of the chunk.
bool is_ready() const noexcept
Check if the chunk is ready for processing.
static constexpr ChunkID chunk_id(std::uint64_t sequence, Rank origin)
Create a ChunkID from a sequence number and origin rank.
static constexpr std::uint64_t RANK_BITS
Number of bits used for the rank in the chunk identifier.
Definition: allgather.hpp:183
std::unique_ptr< Buffer > release_data_buffer() noexcept
Release and return the data buffer.
static std::unique_ptr< Chunk > from_packed_data(std::uint64_t sequence, Rank origin, PackedData &&packed_data)
Create a data chunk from packed data.
std::uint64_t metadata_size() const noexcept
The size of the metadata buffer in bytes.
A thread-safe container for managing chunks in an AllGather.
Definition: allgather.hpp:255
void increment_goalpost(std::uint64_t amount)
Increment the goalpost to a new expected chunk count.
~PostBox()=default
Default destructor.
PostBox & operator=(PostBox const &)=delete
Deleted copy assignment operator.
bool ready() const noexcept
Check if the postbox has reached its goal.
PostBox()=default
Default constructor.
PostBox(PostBox const &)=delete
Deleted copy constructor.
PostBox & operator=(PostBox &&)=delete
Deleted move assignment operator.
void insert(std::vector< std::unique_ptr< Chunk >> &&chunks)
Insert multiple chunks into the postbox.
void insert(std::unique_ptr< Chunk > chunk)
Insert a single chunk into the postbox.
PostBox(PostBox &&)=delete
Deleted move constructor.
Allgather communication interfaces.
Bag of bytes with metadata suitable for sending over the wire.
Definition: packed_data.hpp:26