allgather.hpp
1 
5 #pragma once
6 
7 #include <atomic>
8 #include <chrono>
9 #include <condition_variable>
10 #include <cstdint>
11 #include <functional>
12 #include <memory>
13 #include <mutex>
14 #include <optional>
15 #include <vector>
16 
17 #include <rmm/cuda_stream_view.hpp>
18 
19 #include <rapidsmpf/coll/utils.hpp>
20 #include <rapidsmpf/communicator/communicator.hpp>
21 #include <rapidsmpf/memory/buffer.hpp>
22 #include <rapidsmpf/memory/buffer_resource.hpp>
23 #include <rapidsmpf/memory/packed_data.hpp>
24 #include <rapidsmpf/memory/spill_manager.hpp>
25 #include <rapidsmpf/progress_thread.hpp>
26 #include <rapidsmpf/statistics.hpp>
27 
35 namespace rapidsmpf::coll {
36 
/// @brief AllGather communication service.
///
/// Copy construction, copy assignment and move construction are explicitly
/// deleted in this class.
///
/// NOTE(review): gaps in the embedded listing numbers suggest that a few
/// declaration lines were lost from this listing: the opening `AllGather(` of
/// the constructor, `void insert_finished()`, and the deleted move assignment
/// operator are all named in the generated reference text but are not visible
/// here. Confirm against the original header before relying on this listing.
58 class AllGather {
59  public:
/// @brief Insert packed data into the allgather operation.
/// @param sequence_number 64-bit number accompanying this insertion
///        (presumably used to order the extracted data -- TODO confirm).
/// @param packed_data Data to contribute; taken by rvalue reference.
66  void insert(std::uint64_t sequence_number, PackedData&& packed_data);
67 
// NOTE(review): the reference text lists `void insert_finished()` ("Mark that
// this rank has finished contributing data."), but its declaration does not
// appear in this listing; it presumably belongs in this gap -- confirm.
72 
/// @brief Tag requesting ordering for extraction.
///
/// The underlying type is `bool`, so `NO` has value 0 and `YES` has value 1.
74  enum class Ordered : bool {
/// Extraction is unordered.
75  NO,
/// Extraction is ordered.
76  YES,
77  };
78 
/// @brief Wait for completion and extract all gathered data.
/// @param ordered Whether extraction is ordered; defaults to `Ordered::YES`.
/// @param timeout Time limit in milliseconds; defaults to -1
///        (presumably meaning "no timeout" -- TODO confirm).
/// @return The gathered data as a vector of `PackedData`. The result is
///         marked `[[nodiscard]]`, so discarding it triggers a warning.
93  [[nodiscard]] std::vector<PackedData> wait_and_extract(
94  Ordered ordered = Ordered::YES,
95  std::chrono::milliseconds timeout = std::chrono::milliseconds{-1}
96  );
97 
/// @brief Construct a new allgather operation.
///
/// NOTE(review): the opening `AllGather(` line of this declaration is missing
/// from the listing; only the parameter list below is visible.
///
/// @param comm Communicator, passed as a shared pointer.
/// @param op_id Operation ID for this allgather.
/// @param br Buffer resource, passed as a raw pointer.
/// @param statistics Statistics instance; defaults to `Statistics::disabled()`.
/// @param finished_callback Callback taking no arguments; defaults to
///        `nullptr` (presumably invoked once the operation has finished --
///        TODO confirm against the implementation).
118  std::shared_ptr<Communicator> comm,
119  OpID op_id,
120  BufferResource* br,
121  std::shared_ptr<Statistics> statistics = Statistics::disabled(),
122  std::function<void(void)>&& finished_callback = nullptr
123  );
124 
/// @brief Deleted copy constructor.
126  AllGather(AllGather const&) = delete;
/// @brief Deleted copy assignment operator.
128  AllGather& operator=(AllGather const&) = delete;
/// @brief Deleted move constructor.
130  AllGather(AllGather&&) = delete;
// NOTE(review): the reference text also names a deleted move assignment
// operator, `AllGather& operator=(AllGather&&) = delete;`, which is not
// visible in this listing -- confirm.
133 
/// @brief Gets the communicator associated with this AllGather.
/// @return Const reference to the shared pointer stored in `comm_`.
139  [[nodiscard]] std::shared_ptr<Communicator> const& comm() const noexcept {
140  return comm_;
141  }
142 
/// @brief Destructor (declared `noexcept`).
150  ~AllGather() noexcept;
151 
/// @brief Main event loop for processing allgather operations.
/// @return A `ProgressThread::ProgressState` value.
160  ProgressThread::ProgressState event_loop();
161 
162  private:
/// @brief Private overload of insert() that takes a chunk by `unique_ptr`
///        value, i.e. ownership of the chunk is transferred to the callee.
/// @param chunk The chunk to insert.
168  void insert(std::unique_ptr<detail::Chunk> chunk);
169 
/// @brief Record a finish marker (declared `noexcept`).
/// @param expected_chunks Chunk count associated with the marker
///        (presumably the number of chunks to expect -- TODO confirm).
176  void mark_finish(std::uint64_t expected_chunks) noexcept;
177 
/// @brief Wait, with an optional time limit.
/// @param timeout Time limit in milliseconds; defaults to -1
///        (presumably meaning "no timeout" -- TODO confirm).
185  void wait(std::chrono::milliseconds timeout = std::chrono::milliseconds{-1});
186 
/// @brief Spill function of this object.
/// @param amount Optional amount; defaults to `std::nullopt`
///        (presumably a byte count to spill -- TODO confirm).
/// @return A size value (presumably the amount actually spilled -- confirm).
194  std::size_t spill(std::optional<std::size_t> amount = std::nullopt);
195 
// Communicator handle; this is what comm() returns.
196  std::shared_ptr<Communicator> comm_;
// Buffer resource pointer and statistics handle.
197  BufferResource* br_;
198  std::shared_ptr<Statistics> statistics_;
// Callback member; empty (nullptr) unless replaced.
199  std::function<void(void)> finished_callback_{
200  nullptr
201  };
// Atomic counters. The first two have no in-class initializer here, so they
// are presumably initialized by the constructor -- confirm.
202  std::atomic<Rank> finish_counter_;
203  std::atomic<std::uint32_t> nlocal_insertions_;
204  std::atomic<std::uint64_t> extraction_goalpost_{
205  0
206  };
207  OpID op_id_;
// State flags, both initially false.
208  std::atomic<bool> locally_finished_{false};
209  bool can_extract_{false};
// `mutable` so the mutex can also be locked from const member functions.
210  mutable std::mutex mutex_;
211  std::condition_variable cv_;
// Two post boxes; judging by the names, one for inserted data and one for data
// ready for extraction -- TODO confirm.
212  detail::PostBox inserted_{};
213  detail::PostBox for_extraction_{};
// Identifiers for a progress-thread function and a spill function
// (presumably the registrations of event_loop() and spill() -- confirm).
214  ProgressThread::FunctionID function_id_{};
215  SpillManager::SpillFunctionID spill_function_id_{};
216  // We track remote finishes separately from the finish_counter_ above since the path
217  // through the event loop state machine for a local finish marker is slightly
218  // different from a remote finish marker.
// NOTE(review): no in-class initializer; presumably set by the constructor.
220  Rank remote_finish_counter_;
// Message counters, both starting at zero.
222  std::uint64_t num_expected_messages_{0};
224  std::uint64_t num_received_messages_{0};
// Containers of chunks and communicator futures, all initially empty. The
// exact role of each is defined by the implementation, which is not shown.
226  std::vector<std::unique_ptr<detail::Chunk>> to_receive_{};
228  std::vector<std::unique_ptr<Communicator::Future>> fire_and_forget_{};
230  std::vector<std::unique_ptr<detail::Chunk>> sent_posted_{};
232  std::vector<std::unique_ptr<Communicator::Future>> sent_futures_{};
234  std::vector<std::unique_ptr<detail::Chunk>> receive_posted_{};
236  std::vector<std::unique_ptr<Communicator::Future>> receive_futures_{};
237 };
238 
239 } // namespace rapidsmpf::coll
Class managing buffer resources.
A progress thread that can execute arbitrary functions.
std::size_t SpillFunctionID
Represents a unique identifier for a registered spill function.
static std::shared_ptr< Statistics > disabled()
Returns a shared pointer to a disabled (no-op) Statistics instance.
AllGather communication service.
Definition: allgather.hpp:58
AllGather & operator=(AllGather const &)=delete
Deleted copy assignment operator.
AllGather(AllGather &&)=delete
Deleted move constructor.
~AllGather() noexcept
Destructor.
Ordered
Tag requesting ordering for extraction.
Definition: allgather.hpp:74
@ YES
Extraction is ordered.
@ NO
Extraction is unordered.
void insert(std::uint64_t sequence_number, PackedData &&packed_data)
Insert packed data into the allgather operation.
AllGather & operator=(AllGather &&)=delete
Deleted move assignment operator.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this AllGather.
Definition: allgather.hpp:139
std::vector< PackedData > wait_and_extract(Ordered ordered=Ordered::YES, std::chrono::milliseconds timeout=std::chrono::milliseconds{-1})
Wait for completion and extract all gathered data.
ProgressThread::ProgressState event_loop()
Main event loop for processing allgather operations.
AllGather(std::shared_ptr< Communicator > comm, OpID op_id, BufferResource *br, std::shared_ptr< Statistics > statistics=Statistics::disabled(), std::function< void(void)> &&finished_callback=nullptr)
Construct a new allgather operation.
AllGather(AllGather const &)=delete
Deleted copy constructor.
void insert_finished()
Mark that this rank has finished contributing data.
Collective communication interfaces.
std::int32_t Rank
The rank of a node (e.g. the rank of an MPI process), or world size (total number of ranks).
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations,...
Bag of bytes with metadata suitable for sending over the wire.
Definition: packed_data.hpp:26