allgather.hpp
1 
5 #pragma once
6 
7 #include <atomic>
8 #include <chrono>
9 #include <condition_variable>
10 #include <cstdint>
11 #include <functional>
12 #include <limits>
13 #include <memory>
14 #include <mutex>
15 #include <optional>
16 #include <vector>
17 
18 #include <rmm/cuda_stream_view.hpp>
19 
20 #include <rapidsmpf/communicator/communicator.hpp>
21 #include <rapidsmpf/error.hpp>
22 #include <rapidsmpf/memory/buffer.hpp>
23 #include <rapidsmpf/memory/buffer_resource.hpp>
24 #include <rapidsmpf/memory/packed_data.hpp>
25 #include <rapidsmpf/memory/spill_manager.hpp>
26 #include <rapidsmpf/progress_thread.hpp>
27 #include <rapidsmpf/statistics.hpp>
28 
36 namespace rapidsmpf::coll {
37 namespace detail {
38 
/// @brief Integer type used to identify chunks (see `Chunk::chunk_id`, which builds
/// one from a sequence number and an origin rank).
using ChunkID = std::uint64_t;
41 
51 class Chunk {
52  private:
53  ChunkID id_;
54  std::unique_ptr<std::vector<std::uint8_t>> metadata_;
55  std::unique_ptr<Buffer> data_;
56  std::uint64_t
57  data_size_;
59 
74  Chunk(
75  ChunkID id,
76  std::unique_ptr<std::vector<std::uint8_t>> metadata,
77  std::unique_ptr<Buffer> data
78  );
79 
88  Chunk(ChunkID id);
89 
90  public:
99  [[nodiscard]] bool is_ready() const noexcept;
100 
107  [[nodiscard]] MemoryType memory_type() const noexcept;
108 
114  [[nodiscard]] bool is_finish() const noexcept;
115 
121  [[nodiscard]] ChunkID id() const noexcept;
122 
128  [[nodiscard]] ChunkID sequence() const noexcept;
129 
135  [[nodiscard]] Rank origin() const noexcept;
136 
142  [[nodiscard]] std::uint64_t data_size() const noexcept;
143 
149  [[nodiscard]] std::uint64_t metadata_size() const noexcept;
150 
159  [[nodiscard]] static std::unique_ptr<Chunk> from_packed_data(
160  std::uint64_t sequence, Rank origin, PackedData&& packed_data
161  );
162 
171  [[nodiscard]] static std::unique_ptr<Chunk> from_empty(
172  std::uint64_t num_local_insertions, Rank origin
173  );
174 
185  [[nodiscard]] PackedData release();
186 
188  static constexpr std::uint64_t ID_BITS = 38;
190  static constexpr std::uint64_t RANK_BITS =
191  sizeof(ChunkID) * std::numeric_limits<unsigned char>::digits - ID_BITS;
192 
201  static constexpr ChunkID chunk_id(std::uint64_t sequence, Rank origin);
202 
208  [[nodiscard]] std::unique_ptr<std::vector<std::uint8_t>> serialize() const;
209 
220  [[nodiscard]] static std::unique_ptr<Chunk> deserialize(
221  std::vector<std::uint8_t>& data, BufferResource* br
222  );
223 
229  [[nodiscard]] std::unique_ptr<Buffer> release_data_buffer() noexcept;
230 
239  void attach_data_buffer(std::unique_ptr<Buffer> data);
240 
242  ~Chunk() = default;
244  Chunk(Chunk&&) = default;
247  Chunk& operator=(Chunk&&) = default;
249  Chunk(Chunk const&) = delete;
251  Chunk& operator=(Chunk const&) = delete;
252 };
253 
262 class PostBox {
263  public:
265  PostBox() = default;
267  ~PostBox() = default;
269  PostBox(PostBox const&) = delete;
271  PostBox& operator=(PostBox const&) = delete;
273  PostBox(PostBox&&) = delete;
275  PostBox& operator=(PostBox&&) = delete;
276 
282  void insert(std::unique_ptr<Chunk> chunk);
283 
289  void insert(std::vector<std::unique_ptr<Chunk>>&& chunks);
290 
296  void increment_goalpost(std::uint64_t amount);
297 
304  [[nodiscard]] bool ready() const noexcept;
305 
314  [[nodiscard]] std::vector<std::unique_ptr<Chunk>> extract_ready();
315 
324  [[nodiscard]] std::vector<std::unique_ptr<Chunk>> extract();
325 
331  [[nodiscard]] bool empty() const noexcept;
332 
345  [[nodiscard]] std::size_t spill(BufferResource* br, std::size_t amount);
346 
347  private:
348  mutable std::mutex mutex_{};
349  std::vector<std::unique_ptr<Chunk>> chunks_{};
350  std::atomic<std::uint64_t> goalpost_{0};
351 };
352 
353 } // namespace detail
354 
376 class AllGather {
377  public:
384  void insert(std::uint64_t sequence_number, PackedData&& packed_data);
385 
390 
397  [[nodiscard]] bool finished() const noexcept;
398 
400  enum class Ordered : bool {
401  NO,
402  YES,
403  };
404 
419  [[nodiscard]] std::vector<PackedData> wait_and_extract(
420  Ordered ordered = Ordered::YES,
421  std::chrono::milliseconds timeout = std::chrono::milliseconds{-1}
422  );
423 
444  [[nodiscard]] std::vector<PackedData> extract_ready();
445 
466  std::shared_ptr<Communicator> comm,
467  OpID op_id,
468  BufferResource* br,
469  std::shared_ptr<Statistics> statistics = Statistics::disabled(),
470  std::function<void(void)>&& finished_callback = nullptr
471  );
472 
474  AllGather(AllGather const&) = delete;
476  AllGather& operator=(AllGather const&) = delete;
478  AllGather(AllGather&&) = delete;
481 
487  [[nodiscard]] std::shared_ptr<Communicator> const& comm() const noexcept {
488  return comm_;
489  }
490 
498  ~AllGather() noexcept;
499 
508  ProgressThread::ProgressState event_loop();
509 
510  private:
516  void insert(std::unique_ptr<detail::Chunk> chunk);
517 
524  void mark_finish(std::uint64_t expected_chunks) noexcept;
525 
533  void wait(std::chrono::milliseconds timeout = std::chrono::milliseconds{-1});
534 
542  std::size_t spill(std::optional<std::size_t> amount = std::nullopt);
543 
544  std::shared_ptr<Communicator> comm_;
545  BufferResource* br_;
546  std::shared_ptr<Statistics> statistics_;
547  std::function<void(void)> finished_callback_{
548  nullptr
549  };
550  std::atomic<Rank> finish_counter_;
551  std::atomic<std::uint32_t> nlocal_insertions_;
552  OpID op_id_;
553  std::atomic<bool> locally_finished_{false};
554  std::atomic<bool> active_{true};
555  bool can_extract_{false};
556  mutable std::mutex mutex_;
557  std::condition_variable cv_;
558  detail::PostBox inserted_{};
559  detail::PostBox for_extraction_{};
560  ProgressThread::FunctionID function_id_{};
561  SpillManager::SpillFunctionID spill_function_id_{};
562  // We track remote finishes separately from the finish_counter_ above since the path
563  // through the event loop state machine for a local finish marker is slightly
564  // different from a remote finish marker.
566  Rank remote_finish_counter_;
568  std::uint64_t num_expected_messages_{0};
570  std::uint64_t num_received_messages_{0};
572  std::vector<std::unique_ptr<detail::Chunk>> to_receive_{};
574  std::vector<std::unique_ptr<Communicator::Future>> fire_and_forget_{};
576  std::vector<std::unique_ptr<detail::Chunk>> sent_posted_{};
578  std::vector<std::unique_ptr<Communicator::Future>> sent_futures_{};
580  std::vector<std::unique_ptr<detail::Chunk>> receive_posted_{};
582  std::vector<std::unique_ptr<Communicator::Future>> receive_futures_{};
583 };
584 
585 } // namespace rapidsmpf::coll
Class managing buffer resources.
Buffer representing device or host memory.
Definition: buffer.hpp:47
A progress thread that can execute arbitrary functions.
std::size_t SpillFunctionID
Represents a unique identifier for a registered spill function.
static std::shared_ptr< Statistics > disabled()
Returns a shared pointer to a disabled (no-op) Statistics instance.
AllGather communication service.
Definition: allgather.hpp:376
AllGather & operator=(AllGather const &)=delete
Deleted copy assignment operator.
AllGather(AllGather &&)=delete
Deleted move constructor.
~AllGather() noexcept
Destructor.
Ordered
Tag requesting ordering for extraction.
Definition: allgather.hpp:400
void insert(std::uint64_t sequence_number, PackedData &&packed_data)
Insert packed data into the allgather operation.
AllGather & operator=(AllGather &&)=delete
Deleted move assignment operator.
std::vector< PackedData > extract_ready()
Extract any available partitions.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this AllGather.
Definition: allgather.hpp:487
std::vector< PackedData > wait_and_extract(Ordered ordered=Ordered::YES, std::chrono::milliseconds timeout=std::chrono::milliseconds{-1})
Wait for completion and extract all gathered data.
AllGather(std::shared_ptr< Communicator > comm, OpID op_id, BufferResource *br, std::shared_ptr< Statistics > statistics=Statistics::disabled(), std::function< void(void)> &&finished_callback=nullptr)
Construct a new allgather operation.
AllGather(AllGather const &)=delete
Deleted copy constructor.
bool finished() const noexcept
Check if the allgather operation has completed.
void insert_finished()
Mark that this rank has finished contributing data.
Represents a data chunk in the allgather operation.
Definition: allgather.hpp:51
bool is_ready() const noexcept
Check if the chunk is ready for processing.
static std::unique_ptr< Chunk > deserialize(std::vector< std::uint8_t > &data, BufferResource *br)
Deserialize a chunk from a byte vector of its metadata.
std::unique_ptr< Buffer > release_data_buffer() noexcept
Release and return the data buffer.
void attach_data_buffer(std::unique_ptr< Buffer > data)
Attach a data buffer to this chunk.
static std::unique_ptr< Chunk > from_packed_data(std::uint64_t sequence, Rank origin, PackedData &&packed_data)
Create a data chunk from packed data.
Rank origin() const noexcept
The origin rank of the chunk.
static constexpr std::uint64_t RANK_BITS
Number of bits used for the rank in the chunk identifier.
Definition: allgather.hpp:190
std::uint64_t data_size() const noexcept
The size of the data buffer in bytes.
static std::unique_ptr< Chunk > from_empty(std::uint64_t num_local_insertions, Rank origin)
Create an empty finish marker chunk.
std::uint64_t metadata_size() const noexcept
The size of the metadata buffer in bytes.
static constexpr std::uint64_t ID_BITS
Number of bits used for the sequence ID in the chunk identifier.
Definition: allgather.hpp:188
std::unique_ptr< std::vector< std::uint8_t > > serialize() const
Serialize the metadata of the chunk to a byte vector.
ChunkID sequence() const noexcept
The sequence number of the chunk.
bool is_finish() const noexcept
Check if this is a finish marker chunk.
PackedData release()
Release the chunk's data as PackedData.
MemoryType memory_type() const noexcept
Return the memory type of the chunk.
static constexpr ChunkID chunk_id(std::uint64_t sequence, Rank origin)
Create a ChunkID from a sequence number and origin rank.
A thread-safe container for managing chunks in an AllGather.
Definition: allgather.hpp:262
void insert(std::unique_ptr< Chunk > chunk)
Insert a single chunk into the postbox.
PostBox & operator=(PostBox &&)=delete
Deleted move assignment operator.
bool ready() const noexcept
Check if the postbox has reached its goal.
~PostBox()=default
Default destructor.
void insert(std::vector< std::unique_ptr< Chunk >> &&chunks)
Insert multiple chunks into the postbox.
PostBox()=default
Default constructor.
PostBox(PostBox &&)=delete
Deleted move constructor.
PostBox & operator=(PostBox const &)=delete
Deleted copy assignment operator.
PostBox(PostBox const &)=delete
Deleted copy constructor.
void increment_goalpost(std::uint64_t amount)
Increment the goalpost to a new expected chunk count.
Collective communication interfaces.
std::int32_t Rank
The rank of a node (e.g. the rank of a MPI process), or world size (total number of ranks).
@ YES
Overbooking is allowed.
@ NO
Overbooking is not allowed.
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations, each identified by its own operation ID.
MemoryType
Enum representing the type of memory sorted in decreasing order of preference.
Definition: memory_type.hpp:16
Bag of bytes with metadata suitable for sending over the wire.
Definition: packed_data.hpp:26