9 #include <condition_variable>
20 #include <rapidsmpf/communicator/communicator.hpp>
21 #include <rapidsmpf/error.hpp>
22 #include <rapidsmpf/memory/buffer.hpp>
23 #include <rapidsmpf/memory/buffer_resource.hpp>
24 #include <rapidsmpf/memory/packed_data.hpp>
25 #include <rapidsmpf/memory/spill_manager.hpp>
26 #include <rapidsmpf/progress_thread.hpp>
27 #include <rapidsmpf/statistics.hpp>
40 using ChunkID = std::uint64_t;
54 std::unique_ptr<std::vector<std::uint8_t>> metadata_;
55 std::unique_ptr<Buffer> data_;
76 std::unique_ptr<std::vector<std::uint8_t>> metadata,
77 std::unique_ptr<Buffer> data
121 [[nodiscard]] ChunkID
id() const noexcept;
172 std::uint64_t num_local_insertions,
Rank origin
191 sizeof(ChunkID) * std::numeric_limits<
unsigned char>::digits -
ID_BITS;
208 [[nodiscard]] std::unique_ptr<std::vector<std::uint8_t>>
serialize() const;
282 void insert(std::unique_ptr<Chunk> chunk);
289 void insert(std::vector<std::unique_ptr<Chunk>>&& chunks);
304 [[nodiscard]]
bool ready() const noexcept;
314 [[nodiscard]] std::vector<std::unique_ptr<
Chunk>> extract_ready();
324 [[nodiscard]] std::vector<std::unique_ptr<
Chunk>> extract();
331 [[nodiscard]]
bool empty() const noexcept;
348 mutable std::mutex mutex_{};
349 std::vector<std::unique_ptr<Chunk>> chunks_{};
350 std::atomic<std::uint64_t> goalpost_{0};
420 Ordered ordered = Ordered::YES,
421 std::chrono::milliseconds timeout = std::chrono::milliseconds{-1}
466 std::shared_ptr<Communicator> comm,
470 std::function<
void(
void)>&& finished_callback =
nullptr
487 [[nodiscard]] std::shared_ptr<Communicator>
const&
comm() const noexcept {
516 void insert(std::unique_ptr<detail::Chunk> chunk);
524 void mark_finish(std::uint64_t expected_chunks) noexcept;
533 void wait(std::chrono::milliseconds timeout = std::chrono::milliseconds{-1});
542 std::size_t spill(std::optional<std::size_t> amount = std::nullopt);
544 std::shared_ptr<Communicator> comm_;
546 std::shared_ptr<Statistics> statistics_;
547 std::function<void(
void)> finished_callback_{
550 std::atomic<Rank> finish_counter_;
551 std::atomic<std::uint32_t> nlocal_insertions_;
553 std::atomic<bool> locally_finished_{
false};
554 std::atomic<bool> active_{
true};
555 bool can_extract_{
false};
556 mutable std::mutex mutex_;
557 std::condition_variable cv_;
558 detail::PostBox inserted_{};
559 detail::PostBox for_extraction_{};
560 ProgressThread::FunctionID function_id_{};
566 Rank remote_finish_counter_;
568 std::uint64_t num_expected_messages_{0};
570 std::uint64_t num_received_messages_{0};
572 std::vector<std::unique_ptr<detail::Chunk>> to_receive_{};
574 std::vector<std::unique_ptr<Communicator::Future>> fire_and_forget_{};
576 std::vector<std::unique_ptr<detail::Chunk>> sent_posted_{};
578 std::vector<std::unique_ptr<Communicator::Future>> sent_futures_{};
580 std::vector<std::unique_ptr<detail::Chunk>> receive_posted_{};
582 std::vector<std::unique_ptr<Communicator::Future>> receive_futures_{};
Class managing buffer resources.
Buffer representing device or host memory.
A progress thread that can execute arbitrary functions.
std::size_t SpillFunctionID
Represents a unique identifier for a registered spill function.
static std::shared_ptr< Statistics > disabled()
Returns a shared pointer to a disabled (no-op) Statistics instance.
AllGather communication service.
AllGather & operator=(AllGather const &)=delete
Deleted copy assignment operator.
AllGather(AllGather &&)=delete
Deleted move constructor.
~AllGather() noexcept
Destructor.
Ordered
Tag requesting ordering for extraction.
void insert(std::uint64_t sequence_number, PackedData &&packed_data)
Insert packed data into the allgather operation.
AllGather & operator=(AllGather &&)=delete
Deleted move assignment operator.
std::vector< PackedData > extract_ready()
Extract any available partitions.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this AllGather.
std::vector< PackedData > wait_and_extract(Ordered ordered=Ordered::YES, std::chrono::milliseconds timeout=std::chrono::milliseconds{-1})
Wait for completion and extract all gathered data.
AllGather(std::shared_ptr< Communicator > comm, OpID op_id, BufferResource *br, std::shared_ptr< Statistics > statistics=Statistics::disabled(), std::function< void(void)> &&finished_callback=nullptr)
Construct a new allgather operation.
AllGather(AllGather const &)=delete
Deleted copy constructor.
bool finished() const noexcept
Check if the allgather operation has completed.
void insert_finished()
Mark that this rank has finished contributing data.
Represents a data chunk in the allgather operation.
bool is_ready() const noexcept
Check if the chunk is ready for processing.
static std::unique_ptr< Chunk > deserialize(std::vector< std::uint8_t > &data, BufferResource *br)
Deserialize a chunk from a byte vector of its metadata.
std::unique_ptr< Buffer > release_data_buffer() noexcept
Release and return the data buffer.
void attach_data_buffer(std::unique_ptr< Buffer > data)
Attach a data buffer to this chunk.
static std::unique_ptr< Chunk > from_packed_data(std::uint64_t sequence, Rank origin, PackedData &&packed_data)
Create a data chunk from packed data.
Rank origin() const noexcept
The origin rank of the chunk.
static constexpr std::uint64_t RANK_BITS
Number of bits used for the rank in the chunk identifier.
std::uint64_t data_size() const noexcept
The size of the data buffer in bytes.
static std::unique_ptr< Chunk > from_empty(std::uint64_t num_local_insertions, Rank origin)
Create an empty finish marker chunk.
std::uint64_t metadata_size() const noexcept
The size of the metadata buffer in bytes.
static constexpr std::uint64_t ID_BITS
Number of bits used for the sequence ID in the chunk identifier.
std::unique_ptr< std::vector< std::uint8_t > > serialize() const
Serialize the metadata of the chunk to a byte vector.
ChunkID sequence() const noexcept
The sequence number of the chunk.
bool is_finish() const noexcept
Check if this is a finish marker chunk.
PackedData release()
Release the chunk's data as PackedData.
MemoryType memory_type() const noexcept
Return the memory type of the chunk.
static constexpr ChunkID chunk_id(std::uint64_t sequence, Rank origin)
Create a ChunkID from a sequence number and origin rank.
A thread-safe container for managing chunks in an AllGather.
void insert(std::unique_ptr< Chunk > chunk)
Insert a single chunk into the postbox.
PostBox & operator=(PostBox &&)=delete
Deleted move assignment operator.
bool ready() const noexcept
Check if the postbox has reached its goal.
~PostBox()=default
Default destructor.
void insert(std::vector< std::unique_ptr< Chunk >> &&chunks)
Insert multiple chunks into the postbox.
PostBox()=default
Default constructor.
PostBox(PostBox &&)=delete
Deleted move constructor.
PostBox & operator=(PostBox const &)=delete
Deleted copy assignment operator.
PostBox(PostBox const &)=delete
Deleted copy constructor.
void increment_goalpost(std::uint64_t amount)
Increment the goalpost to a new expected chunk count.
Collective communication interfaces.
std::int32_t Rank
The rank of a node (e.g. the rank of a MPI process), or world size (total number of ranks).
@ YES
Overbooking is allowed.
@ NO
Overbooking is not allowed.
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations,...
MemoryType
Enum representing the type of memory sorted in decreasing order of preference.
Bag of bytes with metadata suitable for sending over the wire.