mean_squared_error.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
7 #include "../matrix/data.hpp"
8 #include "../matrix/part_descriptor.hpp"
9 
10 #include <raft/core/comms.hpp>
11 #include <raft/core/device_mdspan.hpp>
12 
13 namespace MLCommon {
14 namespace LinAlg {
15 namespace opg {
16 
/**
 * @brief Multi-GPU mean squared error between two distributed inputs.
 *
 * Declared for both double and float; the two overloads take the same
 * parameters in the same order.
 *
 * @param[out] out     destination for the scalar result.
 *                     NOTE(review): whether this must be device or host memory
 *                     is not visible in this header — confirm against the
 *                     implementation.
 * @param[in]  in1     this worker's data blocks for the first input
 *                     (Matrix::Data is the wrapper around the multi-GPU data
 *                     blocks owned by a worker).
 * @param[in]  in1Desc Matrix::PartDescriptor for in1 (presumably describes how
 *                     in1 is partitioned across ranks — TODO confirm).
 * @param[in]  in2     this worker's data blocks for the second input
 *                     (presumably must be partitioned like in1 — TODO confirm).
 * @param[in]  in2Desc Matrix::PartDescriptor for in2.
 * @param[in]  comm    raft communicator shared by the participating workers.
 * @param[in]  stream  CUDA stream the computation is presumably ordered on.
 * @param[in]  root    rank index, defaults to 0 (presumably the rank that
 *                     receives the reduced result — TODO confirm).
 * @param[in]  broadcastResult defaults to true (presumably, when true, the
 *                     result is broadcast from root to every rank — TODO
 *                     confirm).
 */
30 void meanSquaredError(double* out,
31  const Matrix::Data<double>& in1,
32  const Matrix::PartDescriptor& in1Desc,
33  const Matrix::Data<double>& in2,
34  const Matrix::PartDescriptor& in2Desc,
35  const raft::comms::comms_t& comm,
36  cudaStream_t stream,
37  int root = 0,
38  bool broadcastResult = true);
// Single-precision overload; same parameters and defaults as the double version.
39 void meanSquaredError(float* out,
40  const Matrix::Data<float>& in1,
41  const Matrix::PartDescriptor& in1Desc,
42  const Matrix::Data<float>& in2,
43  const Matrix::PartDescriptor& in2Desc,
44  const raft::comms::comms_t& comm,
45  cudaStream_t stream,
46  int root = 0,
47  bool broadcastResult = true);
48 
49 } // end namespace opg
50 } // end namespace LinAlg
51 } // end namespace MLCommon
void meanSquaredError(double *out, const Matrix::Data< double > &in1, const Matrix::PartDescriptor &in1Desc, const Matrix::Data< double > &in2, const Matrix::PartDescriptor &in2Desc, const raft::comms::comms_t &comm, cudaStream_t stream, int root=0, bool broadcastResult=true)
Multi-GPU mean squared error
Definition: comm_utils.h:11
This is a helper wrapper around the multi-GPU data blocks owned by a worker. Its design is NOT final...
Definition: data.hpp:18
Definition: part_descriptor.hpp:40