allreduce.hpp
1 
6 #pragma once
7 
8 #include <memory>
9 
10 #include <coro/event.hpp>
11 #include <coro/task.hpp>
12 
13 #include <rapidsmpf/coll/allreduce.hpp>
14 #include <rapidsmpf/communicator/communicator.hpp>
15 #include <rapidsmpf/streaming/core/context.hpp>
16 
17 namespace rapidsmpf::streaming {
18 
25 class AllReduce {
26  public:
40  std::shared_ptr<Context> ctx,
41  std::shared_ptr<Communicator> comm,
42  std::unique_ptr<Buffer> input,
43  std::unique_ptr<Buffer> output,
44  OpID op_id,
45  coll::ReduceOperator reduce_operator
46  );
47 
48  AllReduce(AllReduce const&) = delete;
49  AllReduce& operator=(AllReduce const&) = delete;
50  AllReduce(AllReduce&&) = delete;
51  AllReduce& operator=(AllReduce&&) = delete;
52 
53  ~AllReduce() noexcept;
54 
60  [[nodiscard]] std::shared_ptr<Context> const& ctx() const noexcept;
61 
67  [[nodiscard]] std::shared_ptr<Communicator> const& comm() const noexcept;
68 
77  coro::task<std::pair<std::unique_ptr<Buffer>, std::unique_ptr<Buffer>>> extract();
78 
79  private:
80  coro::event
81  event_{};
82  std::shared_ptr<Context> ctx_;
83  coll::AllReduce reducer_;
84 };
85 
86 } // namespace rapidsmpf::streaming
Buffer representing device or host memory.
Definition: buffer.hpp:47
Abstract base class for a communication mechanism between nodes.
AllReduce collective.
Definition: allreduce.hpp:67
Asynchronous (coroutine) interface to coll::AllReduce.
Definition: allreduce.hpp:25
coro::task< std::pair< std::unique_ptr< Buffer >, std::unique_ptr< Buffer > > > extract()
Wait for completion and extract the reduced data.
AllReduce(std::shared_ptr< Context > ctx, std::shared_ptr< Communicator > comm, std::unique_ptr< Buffer > input, std::unique_ptr< Buffer > output, OpID op_id, coll::ReduceOperator reduce_operator)
Construct an asynchronous allreduce.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this AllReduce.
std::shared_ptr< Context > const & ctx() const noexcept
Gets the streaming context associated with this AllReduce object.
Context for actors (coroutines) in rapidsmpf.
Definition: context.hpp:41
std::function< void(Buffer const *left, Buffer *right)> ReduceOperator
Type alias for the reduction function signature.
Definition: allreduce.hpp:40
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations,...