comm_utils.h
Go to the documentation of this file.
/*
 * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */
5 
6 #pragma once
7 
8 #include <raft/core/comms.hpp>
9 #include <raft/util/cuda_utils.cuh>
10 
namespace MLCommon {
namespace opg {

namespace detail {

/**
 * Synchronizes `stream` through the communicator and reports a failure via
 * ASSERT if the communicator does not return status_t::SUCCESS.
 *
 * Shared by the single-value collectives below so that both use identical
 * error handling and the same diagnostic message.
 *
 * @param comm   communicator whose collective was enqueued on `stream`
 * @param stream CUDA stream the collective was enqueued on
 */
inline void sync_comm_stream_or_fail(const raft::comms::comms_t& comm, cudaStream_t stream)
{
  // A non-SUCCESS status means some rank did not complete the collective;
  // surface it instead of letting callers read an unfinished result.
  ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS,
         "An error occurred in the distributed operation. This can result from "
         "a failed rank");
}

}  // namespace detail

/**
 * Sums a single element across all ranks of `comm` using a reduce rooted at
 * `root`, then synchronizes `stream` and checks the communicator status.
 *
 * NOTE(review): with reduce semantics, only `root` is expected to receive the
 * summed value in `out`. Confirm against the raft comms documentation before
 * reading `out` on other ranks.
 * NOTE(review): `in`/`out` presumably must be device-accessible for the
 * underlying backend (e.g. NCCL). Confirm with the caller.
 *
 * @tparam T     element type of the value being summed
 * @param out    destination for the summed value (one element)
 * @param in     this rank's contribution (one element)
 * @param comm   communicator spanning the participating ranks
 * @param stream CUDA stream the reduction is enqueued on
 * @param root   rank that receives the result (defaults to 0)
 */
template <typename T>
void reduce_single_sum(
  T* out, const T* in, const raft::comms::comms_t& comm, cudaStream_t stream, int root = 0)
{
  comm.reduce(in, out, 1, raft::comms::op_t::SUM, root, stream);
  detail::sync_comm_stream_or_fail(comm, stream);
}

/**
 * Sums a single element across all ranks of `comm` using an allreduce, then
 * synchronizes `stream` and checks the communicator status.
 *
 * NOTE(review): `in`/`out` presumably must be device-accessible for the
 * underlying backend (e.g. NCCL). Confirm with the caller.
 *
 * @tparam T     element type of the value being summed
 * @param out    destination for the summed value (one element) on every rank
 * @param in     this rank's contribution (one element)
 * @param comm   communicator spanning the participating ranks
 * @param stream CUDA stream the allreduce is enqueued on
 */
template <typename T>
void allreduce_single_sum(T* out,
                          const T* in,
                          const raft::comms::comms_t& comm,
                          cudaStream_t stream)
{
  comm.allreduce(in, out, 1, raft::comms::op_t::SUM, stream);
  detail::sync_comm_stream_or_fail(comm, stream);
}

}  // end namespace opg
}  // end namespace MLCommon
void reduce_single_sum(T *out, const T *in, const raft::comms::comms_t &comm, cudaStream_t stream, int root=0)
Definition: comm_utils.h:16
void allreduce_single_sum(T *out, const T *in, const raft::comms::comms_t &comm, cudaStream_t stream)
Definition: comm_utils.h:27
Definition: comm_utils.h:11