sparse_alltoall.hpp
1 
5 #pragma once
6 
7 #include <atomic>
8 #include <chrono>
9 #include <condition_variable>
10 #include <cstdint>
11 #include <functional>
12 #include <memory>
13 #include <mutex>
14 #include <unordered_map>
15 #include <vector>
16 
17 #include <rapidsmpf/coll/utils.hpp>
18 #include <rapidsmpf/communicator/communicator.hpp>
19 #include <rapidsmpf/memory/buffer_resource.hpp>
20 #include <rapidsmpf/memory/packed_data.hpp>
21 #include <rapidsmpf/progress_thread.hpp>
22 
23 namespace rapidsmpf::coll {
24 
36  public:
    /**
     * @brief Construct a sparse all-to-all collective instance.
     *
     * NOTE(review): the constructor's name line (and its doc comment) falls in
     * a gap of this extracted view; the parameter list below belongs to
     * SparseAlltoall(...).
     *
     * @param comm Communicator used for the exchange (shared ownership).
     * @param op_id User-defined operation ID, allowing multiple concurrent
     *        operations to be distinguished.
     * @param br Buffer resource used for allocations (non-owning pointer).
     * @param srcs Ranks this instance expects to receive from.
     * @param dsts Ranks this instance may send to.
     * @param finished_callback Optional callback; presumably invoked on local
     *        completion — TODO(review) confirm against the implementation.
     */
72  std::shared_ptr<Communicator> comm,
73  OpID op_id,
74  BufferResource* br,
75  std::vector<Rank> srcs,
76  std::vector<Rank> dsts,
77  std::function<void()>&& finished_callback = nullptr
78  );
79
80  ~SparseAlltoall() noexcept;
81
    // Non-copyable and non-movable. Likely because the instance registers a
    // progress-thread function (function_id_) that must keep a stable address —
    // TODO(review) confirm.
82  SparseAlltoall(SparseAlltoall const&) = delete;
83  SparseAlltoall& operator=(SparseAlltoall const&) = delete;
84  SparseAlltoall(SparseAlltoall&&) = delete;
85  SparseAlltoall& operator=(SparseAlltoall&&) = delete;
86
    /// @brief Gets the communicator associated with this SparseAlltoall.
    /// @return Reference to the shared communicator passed at construction.
92  [[nodiscard]] std::shared_ptr<Communicator> const& comm() const noexcept;
93
    /// @brief Insert data to send to a destination rank.
    /// @param dst Destination rank.
    /// @param packed_data Packed bytes-with-metadata, consumed by this call.
110  void insert(Rank dst, PackedData&& packed_data);
111
    // NOTE(review): a declaration is missing from this extracted view here —
    // per the extracted docs it is insert_finished(), which indicates that no
    // more data will be inserted for any destination.
122
    /// @brief Wait for local completion.
    /// @param timeout Maximum time to wait; defaults to -1 ms, presumably a
    ///        "wait indefinitely" sentinel — TODO(review) confirm.
130  void wait(std::chrono::milliseconds timeout = std::chrono::milliseconds{-1});
131
    /// @brief Extract all received messages from a source rank.
    /// @param src Source rank to drain.
    /// @return All packed data received from @p src.
145  [[nodiscard]] std::vector<PackedData> extract(Rank src);
146
147  private:
148  struct SourceState {
149  std::uint64_t expected_count{0};
150  std::uint64_t received_count{0};
151  std::vector<std::unique_ptr<detail::Chunk>> chunks{};
152  std::vector<std::unique_ptr<detail::Chunk>> incoming{};
153 
154  [[nodiscard]] bool ready() const noexcept {
155  return expected_count > 0 && expected_count == received_count;
156  }
157  };
158
    // --- Progress-thread steps (names suggest the event-loop phases; bodies
    // are not visible in this view, so descriptions are hedged) ---
    /// Presumably posts sends for outgoing chunks that are ready to go.
160  void send_ready_messages();
    /// Presumably receives metadata announcing incoming chunks.
162  void receive_metadata_messages();
    /// Presumably posts receives for announced data payloads.
164  void receive_data_messages();
    /// Presumably finalizes data receives whose futures completed.
166  void complete_data_messages();
    /// True when all internal containers are drained — TODO(review) confirm.
168  [[nodiscard]] bool containers_empty() const;
    /// Progress-thread driver; returns InProgress or Done.
170  [[nodiscard]] ProgressThread::ProgressState event_loop();
171
    // Construction-time configuration (mirrors the constructor parameters).
172  std::shared_ptr<Communicator> comm_;
173  BufferResource* br_;
174  std::vector<Rank> srcs_;
175  std::vector<Rank> dsts_;
    // Per-destination atomic counter; presumably the next chunk sequence
    // number per destination — TODO(review) confirm.
176  std::unordered_map<Rank, std::atomic<std::uint64_t>> next_ordinal_per_dst_;
177
    // Synchronization for wait()/extract(); mutable so const members can lock.
178  mutable std::mutex mutex_;
179  std::condition_variable cv_;
180  OpID op_id_;
    // Set once no more local work remains — TODO(review) confirm trigger.
181  std::atomic<bool> locally_finished_{false};
    // Guarded by mutex_; gates extract() — TODO(review) confirm.
182  bool can_extract_{false};
183
    // In-flight communication state.
184  detail::PostBox outgoing_{};
185  std::vector<std::unique_ptr<detail::Chunk>> receive_posted_;
186  std::vector<std::unique_ptr<Communicator::Future>> receive_futures_;
    // Futures that are never awaited individually, only retained until done.
187  std::vector<std::unique_ptr<Communicator::Future>> fire_and_forget_;
    // Per-source receive bookkeeping (see SourceState above).
188  std::unordered_map<Rank, SourceState> source_states_;
189  std::function<void()> finished_callback_;
    // Handle of the function registered with the progress thread.
190  ProgressThread::FunctionID function_id_;
191 };
192 
193 } // namespace rapidsmpf::coll
Class managing buffer resources.
Abstract base class for a communication mechanism between nodes.
ProgressState
The progress state of a function, can be either InProgress or Done.
Sparse all-to-all collective over explicit source and destination peer sets.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this SparseAlltoall.
std::vector< PackedData > extract(Rank src)
Extract all received messages from a source rank.
SparseAlltoall(std::shared_ptr< Communicator > comm, OpID op_id, BufferResource *br, std::vector< Rank > srcs, std::vector< Rank > dsts, std::function< void()> &&finished_callback=nullptr)
Construct a sparse all-to-all collective instance.
void wait(std::chrono::milliseconds timeout=std::chrono::milliseconds{-1})
Wait for local completion.
void insert(Rank dst, PackedData &&packed_data)
Insert data to send to a destination rank.
void insert_finished()
Indicate that no more data will be inserted for any destination.
Collective communication interfaces.
std::int32_t Rank
The rank of a node (e.g. the rank of an MPI process), or world size (total number of ranks).
std::int32_t OpID
Operation ID defined by the user. This allows users to execute multiple operations concurrently, each distinguished by its own ID.
Bag of bytes with metadata suitable for sending over the wire.
Definition: packed_data.hpp:26