8 #include <raft/core/comms.hpp>
9 #include <raft/util/cuda_utils.cuh>
17 T* out,
const T* in,
const raft::comms::comms_t& comm, cudaStream_t stream,
int root = 0)
19 comm.reduce(in, out, 1, raft::comms::op_t::SUM, root, stream);
20 ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS,
21 "An error occurred in the distributed operation. This can result from "
29 const raft::comms::comms_t& comm,
32 comm.allreduce(in, out, 1, raft::comms::op_t::SUM, stream);
33 ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS,
34 "An error occurred in the distributed operation. This can result from "
void reduce_single_sum(T *out, const T *in, const raft::comms::comms_t &comm, cudaStream_t stream, int root=0)
Definition: comm_utils.h:16
void allreduce_single_sum(T *out, const T *in, const raft::comms::comms_t &comm, cudaStream_t stream)
Definition: comm_utils.h:27
Definition: comm_utils.h:11