Namespaces | Functions
mean_squared_error.hpp File Reference
#include "../matrix/data.hpp"
#include "../matrix/part_descriptor.hpp"
#include <raft/core/comms.hpp>
#include <raft/core/device_mdspan.hpp>
Include dependency graph for mean_squared_error.hpp:

Go to the source code of this file.

Namespaces

 MLCommon
 
 MLCommon::LinAlg
 
 MLCommon::LinAlg::opg
 

Functions

void MLCommon::LinAlg::opg::meanSquaredError (double *out, const Matrix::Data< double > &in1, const Matrix::PartDescriptor &in1Desc, const Matrix::Data< double > &in2, const Matrix::PartDescriptor &in2Desc, const raft::comms::comms_t &comm, cudaStream_t stream, int root=0, bool broadcastResult=true)
 Multi-GPU mean squared error. More...
 
void MLCommon::LinAlg::opg::meanSquaredError (float *out, const Matrix::Data< float > &in1, const Matrix::PartDescriptor &in1Desc, const Matrix::Data< float > &in2, const Matrix::PartDescriptor &in2Desc, const raft::comms::comms_t &comm, cudaStream_t stream, int root=0, bool broadcastResult=true)