/**
 * @file allreduce.hpp
 * @brief AllReduce collective communication interfaces.
 */
5 #pragma once
6 
#include <algorithm>
#include <atomic>
#include <chrono>
#include <concepts>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <span>
#include <tuple>
#include <utility>

#ifdef __CUDACC__
#include <thrust/execution_policy.h>
#include <thrust/transform.h>
#endif

#include <rapidsmpf/communicator/communicator.hpp>
#include <rapidsmpf/error.hpp>
#include <rapidsmpf/memory/buffer.hpp>
#include <rapidsmpf/progress_thread.hpp>
#include <rapidsmpf/statistics.hpp>
27 
28 namespace rapidsmpf::coll {
29 
/// @brief Type alias for the reduction function signature.
///
/// Implementations reduce `left` into `right` element-wise (see `HostOp` /
/// `DeviceOp` below, which compute `right[i] = op(left[i], right[i])`).
using ReduceOperator = std::function<void(Buffer const* left, Buffer* right)>;
41 
67 class AllReduce {
68  public:
89  std::shared_ptr<Communicator> comm,
90  std::unique_ptr<Buffer> input,
91  std::unique_ptr<Buffer> output,
92  OpID op_id,
93  ReduceOperator reduce_operator,
94  std::function<void(void)> finished_callback = nullptr
95  );
96 
102  [[nodiscard]] std::shared_ptr<Communicator> const& comm() const noexcept {
103  return comm_;
104  }
105 
106  AllReduce(AllReduce const&) = delete;
107  AllReduce& operator=(AllReduce const&) = delete;
108  AllReduce(AllReduce&&) = delete;
109  AllReduce& operator=(AllReduce&&) = delete;
110 
118  ~AllReduce() noexcept;
119 
126  [[nodiscard]] bool finished() const noexcept;
127 
157  [[nodiscard]] std::pair<std::unique_ptr<Buffer>, std::unique_ptr<Buffer>>
158  wait_and_extract(std::chrono::milliseconds timeout = std::chrono::milliseconds{-1});
159 
170  [[nodiscard]] bool is_ready() const noexcept;
171 
172  private:
173  enum class Phase : uint8_t {
174  StartPreRemainder,
175  CompletePreRemainder,
176  StartButterfly,
177  CompleteButterfly,
178  StartPostRemainder,
179  CompletePostRemainder,
180  Done,
181  ResultAvailable
182  };
183 
185  [[nodiscard]] ProgressThread::ProgressState event_loop();
186 
187  std::shared_ptr<Communicator> comm_{};
188  ReduceOperator reduce_operator_;
189  std::unique_ptr<Buffer> in_buffer_{};
190  std::unique_ptr<Buffer> out_buffer_{};
191  OpID op_id_{};
192  std::atomic<Phase> phase_{Phase::StartPreRemainder};
193  std::atomic<bool> active_{true};
194  std::function<void()>
195  finished_callback_;
196 
197  mutable std::mutex mutex_;
198  mutable std::condition_variable cv_;
199 
200  Rank logical_rank_{-1};
201  Rank nearest_pow2_{0};
202  Rank non_pow2_remainder_{0};
203  Rank stage_mask_{1};
204  Rank stage_partner_{-1};
205 
206  ProgressThread::FunctionID function_id_{};
207 
208  std::unique_ptr<Communicator::Future> send_future_{};
209  std::unique_ptr<Communicator::Future> recv_future_{};
210 };
211 
212 namespace detail {
213 
222 template <typename T, typename Op>
223 struct HostOp {
224  Op op;
225 
232  void operator()(Buffer const* left, Buffer* right) {
233  auto const left_nbytes = left->size;
234  RAPIDSMPF_EXPECTS(
235  left_nbytes % sizeof(T) == 0,
236  "HostOp buffer size must be a multiple of sizeof(T)"
237  );
238 
239  auto const count = left_nbytes / sizeof(T);
240  if (count == 0) {
241  return;
242  }
243 
244  RAPIDSMPF_EXPECTS(
245  left->mem_type() == MemoryType::HOST && right->mem_type() == MemoryType::HOST,
246  "HostOp expects host memory"
247  );
248 
249  auto* left_bytes = left->data();
250  auto* right_bytes = right->exclusive_data_access();
251 
252  std::span<T const> left_span{reinterpret_cast<T const*>(left_bytes), count};
253  std::span<T> right_span{reinterpret_cast<T*>(right_bytes), count};
254 
255  std::ranges::transform(left_span, right_span, right_span.begin(), op);
256  right->unlock();
257  }
258 };
259 
271 template <typename T, typename Op>
272 struct DeviceOp {
273  Op op;
274 
281  void operator()(Buffer const* left, Buffer* right) {
282 #ifdef __CUDACC__
283  auto const left_nbytes = left->size;
284  RAPIDSMPF_EXPECTS(
285  left_nbytes % sizeof(T) == 0,
286  "DeviceOp buffer size must be a multiple of sizeof(T)"
287  );
288 
289  auto const count = left_nbytes / sizeof(T);
290  if (count == 0) {
291  return;
292  }
293 
294  RAPIDSMPF_EXPECTS(
295  left->mem_type() == MemoryType::DEVICE
296  && right->mem_type() == MemoryType::DEVICE,
297  "DeviceOp expects device memory"
298  );
299  // Both buffers are guaranteed to be on the same stream by the AllReduce ctor.
300  right->write_access([&](std::byte* right_bytes, rmm::cuda_stream_view stream) {
301  auto const* left_bytes = reinterpret_cast<std::byte const*>(left->data());
302 
303  T* right_ptr = reinterpret_cast<T*>(right_bytes);
304  T const* left_ptr = reinterpret_cast<T const*>(left_bytes);
305 
306  thrust::transform(
307  thrust::cuda::par_nosync.on(stream.value()),
308  left_ptr,
309  left_ptr + count,
310  right_ptr,
311  right_ptr,
312  op
313  );
314  });
315 #else
316  // This should never be reached if DeviceOp is only instantiated with CUDA
317  std::ignore = left;
318  std::ignore = right;
319  RAPIDSMPF_FAIL(
320  "DeviceOp::operator() called but CUDA compilation (__CUDACC__) "
321  "was not available. DeviceOp requires CUDA/thrust support.",
322  std::runtime_error
323  );
324 #endif
325  }
326 };
327 
336 template <typename T, typename Op>
337  requires std::invocable<Op, T const&, T const&>
339  return HostOp<T, Op>{std::move(op)};
340 }
341 
353 template <typename T, typename Op>
354  requires std::invocable<Op, T const&, T const&>
356 #ifdef __CUDACC__
357  return DeviceOp<T, Op>{std::move(op)};
358 #else
359  std::ignore = op;
360 
361  RAPIDSMPF_FAIL(
362  "make_device_reduce_operator was called from code that was not compiled "
363  "with NVCC (__CUDACC__ is not defined).",
364  std::runtime_error
365  );
366 #endif
367 }
368 
369 } // namespace detail
370 
371 } // namespace rapidsmpf::coll
Buffer representing device or host memory.
Definition: buffer.hpp:47
auto write_access(F &&f) -> std::invoke_result_t< F, std::byte *, rmm::cuda_stream_view >
Provides stream-ordered write access to the buffer.
Definition: buffer.hpp:129
std::byte * exclusive_data_access()
Acquire non-stream-ordered exclusive access to the buffer's memory.
std::byte const * data() const
Access the underlying memory buffer (host or device memory).
std::size_t const size
The size of the buffer in bytes.
Definition: buffer.hpp:366
void unlock()
Release the exclusive lock acquired by exclusive_data_access().
constexpr MemoryType mem_type() const
Get the memory type of the buffer.
Definition: buffer.hpp:192
ProgressState
The progress state of a function, can be either InProgress or Done.
AllReduce collective.
Definition: allreduce.hpp:67
~AllReduce() noexcept
Destructor.
bool finished() const noexcept
Check if the allreduce operation has completed.
std::pair< std::unique_ptr< Buffer >, std::unique_ptr< Buffer > > wait_and_extract(std::chrono::milliseconds timeout=std::chrono::milliseconds{-1})
Wait for completion and extract the reduced data.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this AllReduce.
Definition: allreduce.hpp:102
AllReduce(std::shared_ptr< Communicator > comm, std::unique_ptr< Buffer > input, std::unique_ptr< Buffer > output, OpID op_id, ReduceOperator reduce_operator, std::function< void(void)> finished_callback=nullptr)
Construct a new AllReduce operation.
bool is_ready() const noexcept
Check if reduced results are ready for extraction.
cudaStream_t value() const noexcept
requires std::invocable< Op, T const &, T const & > ReduceOperator make_device_reduce_operator(Op op)
Create a device-based reduction operator from a typed binary operation.
Definition: allreduce.hpp:355
requires std::invocable< Op, T const &, T const & > ReduceOperator make_host_reduce_operator(Op op)
Create a host-based reduction operator from a typed binary operation.
Definition: allreduce.hpp:338
Collective communication interfaces.
std::function< void(Buffer const *left, Buffer *right)> ReduceOperator
Type alias for the reduction function signature.
Definition: allreduce.hpp:40
std::int32_t Rank
The rank of a node (e.g. the rank of a MPI process), or world size (total number of ranks).
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations,...
@ HOST
Host memory.
@ DEVICE
Device memory.
Device-side range-based reduction operator.
Definition: allreduce.hpp:272
void operator()(Buffer const *left, Buffer *right)
Apply the reduction operator to the packed data ranges.
Definition: allreduce.hpp:281
Op op
The binary reduction operator.
Definition: allreduce.hpp:273
Host-side range-based reduction operator.
Definition: allreduce.hpp:223
void operator()(Buffer const *left, Buffer *right)
Apply the reduction operator to the packed data ranges.
Definition: allreduce.hpp:232
Op op
The binary reduction operator.
Definition: allreduce.hpp:224