allreduce.hpp
1 
#pragma once

#include <algorithm>
#include <atomic>
#include <chrono>
#include <concepts>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <span>
#include <tuple>
#include <utility>

#ifdef __CUDACC__
#include <thrust/execution_policy.h>
#include <thrust/transform.h>
#endif

#include <rapidsmpf/communicator/communicator.hpp>
#include <rapidsmpf/error.hpp>
#include <rapidsmpf/memory/buffer.hpp>
#include <rapidsmpf/progress_thread.hpp>
#include <rapidsmpf/statistics.hpp>
29 
30 namespace rapidsmpf::coll {
31 
39 using ReduceOperator = std::function<void(Buffer const* left, Buffer* right)>;
40 
66 class AllReduce {
67  public:
88  std::shared_ptr<Communicator> comm,
89  std::unique_ptr<Buffer> input,
90  std::unique_ptr<Buffer> output,
91  OpID op_id,
92  ReduceOperator reduce_operator,
93  std::function<void(void)> finished_callback = nullptr
94  );
95 
101  [[nodiscard]] std::shared_ptr<Communicator> const& comm() const noexcept {
102  return comm_;
103  }
104 
105  AllReduce(AllReduce const&) = delete;
106  AllReduce& operator=(AllReduce const&) = delete;
107  AllReduce(AllReduce&&) = delete;
108  AllReduce& operator=(AllReduce&&) = delete;
109 
117  ~AllReduce() noexcept;
118 
125  [[nodiscard]] bool finished() const noexcept;
126 
156  [[nodiscard]] std::pair<std::unique_ptr<Buffer>, std::unique_ptr<Buffer>>
157  wait_and_extract(std::chrono::milliseconds timeout = std::chrono::milliseconds{-1});
158 
169  [[nodiscard]] bool is_ready() const noexcept;
170 
171  private:
172  enum class Phase : uint8_t {
173  StartPreRemainder,
174  CompletePreRemainder,
175  StartButterfly,
176  CompleteButterfly,
177  StartPostRemainder,
178  CompletePostRemainder,
179  Done,
180  ResultAvailable
181  };
182 
184  [[nodiscard]] ProgressThread::ProgressState event_loop();
185 
186  std::shared_ptr<Communicator> comm_{};
187  ReduceOperator reduce_operator_;
188  std::unique_ptr<Buffer> in_buffer_{};
189  std::unique_ptr<Buffer> out_buffer_{};
190  OpID op_id_{};
191  std::atomic<Phase> phase_{Phase::StartPreRemainder};
192  std::atomic<bool> active_{true};
193  std::function<void()>
194  finished_callback_;
195 
196  mutable std::mutex mutex_;
197  mutable std::condition_variable cv_;
198 
199  Rank logical_rank_{-1};
200  Rank nearest_pow2_{0};
201  Rank non_pow2_remainder_{0};
202  Rank stage_mask_{1};
203  Rank stage_partner_{-1};
204 
205  ProgressThread::FunctionID function_id_{};
206 
207  std::unique_ptr<Communicator::Future> send_future_{};
208  std::unique_ptr<Communicator::Future> recv_future_{};
209 };
210 
211 namespace detail {
212 
221 template <typename T, typename Op>
222 struct HostOp {
223  Op op;
224 
231  void operator()(Buffer const* left, Buffer* right) {
232  auto const left_nbytes = left->size;
233  RAPIDSMPF_EXPECTS(
234  left_nbytes % sizeof(T) == 0,
235  "HostOp buffer size must be a multiple of sizeof(T)"
236  );
237 
238  auto const count = left_nbytes / sizeof(T);
239  if (count == 0) {
240  return;
241  }
242 
243  RAPIDSMPF_EXPECTS(
244  left->mem_type() == MemoryType::HOST && right->mem_type() == MemoryType::HOST,
245  "HostOp expects host memory"
246  );
247 
248  auto* left_bytes = left->data();
249  auto* right_bytes = right->exclusive_data_access();
250 
251  std::span<T const> left_span{reinterpret_cast<T const*>(left_bytes), count};
252  std::span<T> right_span{reinterpret_cast<T*>(right_bytes), count};
253 
254  std::ranges::transform(left_span, right_span, right_span.begin(), op);
255  right->unlock();
256  }
257 };
258 
270 template <typename T, typename Op>
271 struct DeviceOp {
272  Op op;
273 
280  void operator()(Buffer const* left, Buffer* right) {
281 #ifdef __CUDACC__
282  auto const left_nbytes = left->size;
283  RAPIDSMPF_EXPECTS(
284  left_nbytes % sizeof(T) == 0,
285  "DeviceOp buffer size must be a multiple of sizeof(T)"
286  );
287 
288  auto const count = left_nbytes / sizeof(T);
289  if (count == 0) {
290  return;
291  }
292 
293  RAPIDSMPF_EXPECTS(
294  left->mem_type() == MemoryType::DEVICE
295  && right->mem_type() == MemoryType::DEVICE,
296  "DeviceOp expects device memory"
297  );
298  // Both buffers are guaranteed to be on the same stream by the AllReduce ctor.
299  right->write_access([&](std::byte* right_bytes, rmm::cuda_stream_view stream) {
300  auto const* left_bytes = reinterpret_cast<std::byte const*>(left->data());
301 
302  T* right_ptr = reinterpret_cast<T*>(right_bytes);
303  T const* left_ptr = reinterpret_cast<T const*>(left_bytes);
304 
305  thrust::transform(
306  thrust::cuda::par_nosync.on(stream.value()),
307  left_ptr,
308  left_ptr + count,
309  right_ptr,
310  right_ptr,
311  op
312  );
313  });
314 #else
315  // This should never be reached if DeviceOp is only instantiated with CUDA
316  std::ignore = left;
317  std::ignore = right;
318  RAPIDSMPF_FAIL(
319  "DeviceOp::operator() called but CUDA compilation (__CUDACC__) "
320  "was not available. DeviceOp requires CUDA/thrust support.",
321  std::runtime_error
322  );
323 #endif
324  }
325 };
326 
335 template <typename T, typename Op>
336  requires std::invocable<Op, T const&, T const&>
337 ReduceOperator make_host_reduce_operator(Op op) {
338  return HostOp<T, Op>{std::move(op)};
339 }
340 
352 template <typename T, typename Op>
353  requires std::invocable<Op, T const&, T const&>
354 ReduceOperator make_device_reduce_operator(Op op) {
355 #ifdef __CUDACC__
356  return DeviceOp<T, Op>{std::move(op)};
357 #else
358  std::ignore = op;
359 
360  RAPIDSMPF_FAIL(
361  "make_device_reduce_operator was called from code that was not compiled "
362  "with NVCC (__CUDACC__ is not defined).",
363  std::runtime_error
364  );
365 #endif
366 }
367 
368 } // namespace detail
369 
370 } // namespace rapidsmpf::coll
Buffer representing device or host memory.
Definition: buffer.hpp:47
auto write_access(F &&f) -> std::invoke_result_t< F, std::byte *, rmm::cuda_stream_view >
Provides stream-ordered write access to the buffer.
Definition: buffer.hpp:130
std::byte * exclusive_data_access()
Acquire non-stream-ordered exclusive access to the buffer's memory.
std::byte const * data() const
Access the underlying memory buffer (host or device memory).
std::size_t const size
The size of the buffer in bytes.
Definition: buffer.hpp:367
void unlock()
Release the exclusive lock acquired by exclusive_data_access().
constexpr MemoryType mem_type() const
Get the memory type of the buffer.
Definition: buffer.hpp:193
ProgressState
The progress state of a function, can be either InProgress or Done.
AllReduce collective.
Definition: allreduce.hpp:66
~AllReduce() noexcept
Destructor.
bool finished() const noexcept
Check if the allreduce operation has completed.
std::pair< std::unique_ptr< Buffer >, std::unique_ptr< Buffer > > wait_and_extract(std::chrono::milliseconds timeout=std::chrono::milliseconds{-1})
Wait for completion and extract the reduced data.
std::shared_ptr< Communicator > const & comm() const noexcept
Gets the communicator associated with this AllReduce.
Definition: allreduce.hpp:101
AllReduce(std::shared_ptr< Communicator > comm, std::unique_ptr< Buffer > input, std::unique_ptr< Buffer > output, OpID op_id, ReduceOperator reduce_operator, std::function< void(void)> finished_callback=nullptr)
Construct a new AllReduce operation.
bool is_ready() const noexcept
Check if reduced results are ready for extraction.
cudaStream_t value() const noexcept
Collective communication interfaces.
std::function< void(Buffer const *left, Buffer *right)> ReduceOperator
Type alias for the reduction function signature.
Definition: allreduce.hpp:39
std::int32_t Rank
The rank of a node (e.g. the rank of a MPI process), or world size (total number of ranks).
std::int32_t OpID
Operation ID defined by the user. This allows users to concurrently execute multiple operations,...
@ HOST
Host memory.
@ DEVICE
Device memory.
Device-side range-based reduction operator.
Definition: allreduce.hpp:271
void operator()(Buffer const *left, Buffer *right)
Apply the reduction operator to the packed data ranges.
Definition: allreduce.hpp:280
Op op
The binary reduction operator.
Definition: allreduce.hpp:272
Host-side range-based reduction operator.
Definition: allreduce.hpp:222
void operator()(Buffer const *left, Buffer *right)
Apply the reduction operator to the packed data ranges.
Definition: allreduce.hpp:231
Op op
The binary reduction operator.
Definition: allreduce.hpp:223