allgather.hpp
1 
6 #pragma once
7 
8 #include <cstdint>
9 #include <memory>
10 #include <vector>
11 
12 #include <rapidsmpf/allgather/allgather.hpp>
13 #include <rapidsmpf/buffer/packed_data.hpp>
14 #include <rapidsmpf/communicator/communicator.hpp>
15 #include <rapidsmpf/streaming/chunks/packed_data.hpp>
16 #include <rapidsmpf/streaming/core/channel.hpp>
17 #include <rapidsmpf/streaming/core/context.hpp>
18 
19 #include <coro/event.hpp>
20 #include <coro/task.hpp>
21 
22 namespace rapidsmpf::streaming {
23 
32 class AllGather {
33  public:
42  AllGather(std::shared_ptr<Context> ctx, OpID op_id);
43 
44  AllGather(AllGather const&) = delete;
45  AllGather& operator=(AllGather const&) = delete;
46  AllGather(AllGather&&) = delete;
47  AllGather& operator=(AllGather&&) = delete;
48 
49  ~AllGather();
50 
56  [[nodiscard]] std::shared_ptr<Context> ctx() const noexcept;
57 
64  void insert(std::uint64_t sequence_number, PackedDataChunk&& chunk);
65 
68 
79  coro::task<std::vector<PackedDataChunk>> extract_all(Ordered ordered = Ordered::YES);
80 
81  private:
82  coro::event
83  event_{};
84  std::shared_ptr<Context> ctx_;
85  allgather::AllGather gatherer_;
86 };
87 
88 namespace node {
89 
110  std::shared_ptr<Context> ctx,
111  std::shared_ptr<Channel> ch_in,
112  std::shared_ptr<Channel> ch_out,
113  OpID op_id,
115 );
116 } // namespace node
117 } // namespace rapidsmpf::streaming
AllGather communication service.
Definition: allgather.hpp:372
Ordered
Tag requesting ordering for extraction.
Definition: allgather.hpp:396
Asynchronous (coroutine) interface to allgather::AllGather.
Definition: allgather.hpp:32
void insert_finished()
Mark that this rank has finished contributing data.
std::shared_ptr< Context > ctx() const noexcept
Gets the streaming context associated with this AllGather object.
AllGather(std::shared_ptr< Context > ctx, OpID op_id)
Construct an asynchronous allgather.
void insert(std::uint64_t sequence_number, PackedDataChunk &&chunk)
Insert a chunk into the allgather.
coro::task< std::vector< PackedDataChunk > > extract_all(Ordered ordered=Ordered::YES)
Extract all gathered data.
Node allgather(std::shared_ptr< Context > ctx, std::shared_ptr< Channel > ch_in, std::shared_ptr< Channel > ch_out, OpID op_id, AllGather::Ordered ordered=AllGather::Ordered::YES)
Create an allgather node for a single allgather operation.
coro::task< void > Node
Alias for a node in a streaming pipeline.
Definition: node.hpp:18