9 #include <condition_variable>
14 #include <unordered_map>
17 #include <rapidsmpf/coll/utils.hpp>
18 #include <rapidsmpf/communicator/communicator.hpp>
19 #include <rapidsmpf/memory/buffer_resource.hpp>
20 #include <rapidsmpf/memory/packed_data.hpp>
21 #include <rapidsmpf/progress_thread.hpp>
72 std::shared_ptr<Communicator>
comm,
75 std::vector<Rank> srcs,
76 std::vector<Rank> dsts,
77 std::function<
void()>&& finished_callback =
nullptr
130 void wait(std::chrono::milliseconds timeout = std::chrono::milliseconds{-1});
149 std::uint64_t expected_count{0};
150 std::uint64_t received_count{0};
151 std::vector<std::unique_ptr<detail::Chunk>> chunks{};
152 std::vector<std::unique_ptr<detail::Chunk>> incoming{};
154 [[nodiscard]]
bool ready() const noexcept {
155 return expected_count > 0 && expected_count == received_count;
160 void send_ready_messages();
162 void receive_metadata_messages();
164 void receive_data_messages();
166 void complete_data_messages();
168 [[nodiscard]]
bool containers_empty()
const;
172 std::shared_ptr<Communicator> comm_;
174 std::vector<Rank> srcs_;
175 std::vector<Rank> dsts_;
176 std::unordered_map<Rank, std::atomic<std::uint64_t>> next_ordinal_per_dst_;
178 mutable std::mutex mutex_;
179 std::condition_variable cv_;
181 std::atomic<bool> locally_finished_{
false};
182 bool can_extract_{
false};
184 detail::PostBox outgoing_{};
185 std::vector<std::unique_ptr<detail::Chunk>> receive_posted_;
186 std::vector<std::unique_ptr<Communicator::Future>> receive_futures_;
187 std::vector<std::unique_ptr<Communicator::Future>> fire_and_forget_;
188 std::unordered_map<Rank, SourceState> source_states_;
189 std::function<void()> finished_callback_;
190 ProgressThread::FunctionID function_id_;
Class managing buffer resources.
Abstract base class for a communication mechanism between nodes.
ProgressState
The progress state of a function; it can be either InProgress or Done.
Sparse all-to-all collective over explicit source and destination peer sets.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this SparseAlltoall.
std::vector< PackedData > extract(Rank src)
Extract all received messages from a source rank.
SparseAlltoall(std::shared_ptr< Communicator > comm, OpID op_id, BufferResource *br, std::vector< Rank > srcs, std::vector< Rank > dsts, std::function< void()> &&finished_callback=nullptr)
Construct a sparse all-to-all collective instance.
void wait(std::chrono::milliseconds timeout=std::chrono::milliseconds{-1})
Wait for local completion.
void insert(Rank dst, PackedData &&packed_data)
Insert data to send to a destination rank.
void insert_finished()
Indicate that no more data will be inserted for any destination.
Collective communication interfaces.
std::int32_t Rank
The rank of a node (e.g. the rank of an MPI process), or world size (total number of ranks).
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations,...
Bag of bytes with metadata suitable for sending over the wire.