9 #include <condition_variable>
19 #include <rapidsmpf/coll/utils.hpp>
20 #include <rapidsmpf/communicator/communicator.hpp>
21 #include <rapidsmpf/memory/buffer.hpp>
22 #include <rapidsmpf/memory/buffer_resource.hpp>
23 #include <rapidsmpf/memory/packed_data.hpp>
24 #include <rapidsmpf/memory/spill_manager.hpp>
25 #include <rapidsmpf/progress_thread.hpp>
26 #include <rapidsmpf/statistics.hpp>
95 std::chrono::milliseconds timeout = std::chrono::milliseconds{-1}
118 std::shared_ptr<Communicator>
comm,
122 std::function<
void(
void)>&& finished_callback =
nullptr
139 [[nodiscard]] std::shared_ptr<Communicator>
const&
comm() const noexcept {
168 void insert(std::unique_ptr<detail::Chunk> chunk);
176 void mark_finish(std::uint64_t expected_chunks) noexcept;
185 void wait(std::chrono::milliseconds timeout = std::chrono::milliseconds{-1});
194 std::size_t spill(std::optional<std::size_t> amount = std::nullopt);
196 std::shared_ptr<Communicator> comm_;
198 std::shared_ptr<Statistics> statistics_;
199 std::function<void(
void)> finished_callback_{
202 std::atomic<Rank> finish_counter_;
203 std::atomic<std::uint32_t> nlocal_insertions_;
204 std::atomic<std::uint64_t> extraction_goalpost_{
208 std::atomic<bool> locally_finished_{
false};
209 bool can_extract_{
false};
210 mutable std::mutex mutex_;
211 std::condition_variable cv_;
212 detail::PostBox inserted_{};
213 detail::PostBox for_extraction_{};
214 ProgressThread::FunctionID function_id_{};
220 Rank remote_finish_counter_;
222 std::uint64_t num_expected_messages_{0};
224 std::uint64_t num_received_messages_{0};
226 std::vector<std::unique_ptr<detail::Chunk>> to_receive_{};
228 std::vector<std::unique_ptr<Communicator::Future>> fire_and_forget_{};
230 std::vector<std::unique_ptr<detail::Chunk>> sent_posted_{};
232 std::vector<std::unique_ptr<Communicator::Future>> sent_futures_{};
234 std::vector<std::unique_ptr<detail::Chunk>> receive_posted_{};
236 std::vector<std::unique_ptr<Communicator::Future>> receive_futures_{};
Class managing buffer resources.
A progress thread that can execute arbitrary functions.
std::size_t SpillFunctionID
Represents a unique identifier for a registered spill function.
static std::shared_ptr< Statistics > disabled()
Returns a shared pointer to a disabled (no-op) Statistics instance.
AllGather communication service.
AllGather & operator=(AllGather const &)=delete
Deleted copy assignment operator.
AllGather(AllGather &&)=delete
Deleted move constructor.
~AllGather() noexcept
Destructor.
Ordered
Tag requesting ordering for extraction.
@ YES
Extraction is ordered.
@ NO
Extraction is unordered.
void insert(std::uint64_t sequence_number, PackedData &&packed_data)
Insert packed data into the allgather operation.
AllGather & operator=(AllGather &&)=delete
Deleted move assignment operator.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this AllGather.
std::vector< PackedData > wait_and_extract(Ordered ordered=Ordered::YES, std::chrono::milliseconds timeout=std::chrono::milliseconds{-1})
Wait for completion and extract all gathered data.
ProgressThread::ProgressState event_loop()
Main event loop for processing allgather operations.
AllGather(std::shared_ptr< Communicator > comm, OpID op_id, BufferResource *br, std::shared_ptr< Statistics > statistics=Statistics::disabled(), std::function< void(void)> &&finished_callback=nullptr)
Construct a new allgather operation.
AllGather(AllGather const &)=delete
Deleted copy constructor.
void insert_finished()
Mark that this rank has finished contributing data.
Collective communication interfaces.
std::int32_t Rank
The rank of a node (e.g. the rank of an MPI process), or world size (total number of ranks).
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations,...
Bag of bytes with metadata suitable for sending over the wire.