mean_center.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
7 #include "../matrix/data.hpp"
8 #include "../matrix/part_descriptor.hpp"
9 
10 #include <raft/core/comms.hpp>
11 
12 namespace MLCommon {
13 namespace Stats {
14 namespace opg {
15 
25 void mean_center(const std::vector<Matrix::Data<double>*>& data,
26  const Matrix::PartDescriptor& dataDesc,
27  const Matrix::Data<double>& mu,
28  const raft::comms::comms_t& comm,
29  cudaStream_t* streams,
30  int n_streams);
31 
32 void mean_center(const std::vector<Matrix::Data<float>*>& data,
33  const Matrix::PartDescriptor& dataDesc,
34  const Matrix::Data<float>& mu,
35  const raft::comms::comms_t& comm,
36  cudaStream_t* streams,
37  int n_streams);
38 
48 void mean_add(const std::vector<Matrix::Data<double>*>& data,
49  const Matrix::PartDescriptor& dataDesc,
50  const Matrix::Data<double>& mu,
51  const raft::comms::comms_t& comm,
52  cudaStream_t* streams,
53  int n_streams);
54 
55 void mean_add(const std::vector<Matrix::Data<float>*>& data,
56  const Matrix::PartDescriptor& dataDesc,
57  const Matrix::Data<float>& mu,
58  const raft::comms::comms_t& comm,
59  cudaStream_t* streams,
60  int n_streams);
61 
62 } // end namespace opg
63 } // end namespace Stats
64 } // end namespace MLCommon
void mean_center(const std::vector< Matrix::Data< double > * > &data, const Matrix::PartDescriptor &dataDesc, const Matrix::Data< double > &mu, const raft::comms::comms_t &comm, cudaStream_t *streams, int n_streams)
Performs MNMG mean subtraction calculation.
void mean_add(const std::vector< Matrix::Data< double > * > &data, const Matrix::PartDescriptor &dataDesc, const Matrix::Data< double > &mu, const raft::comms::comms_t &comm, cudaStream_t *streams, int n_streams)
Performs MNMG mean add calculation.
Definition: comm_utils.h:11
This is a helper wrapper around the multi-GPU data blocks owned by a worker. Its design is NOT final...
Definition: data.hpp:18
Definition: part_descriptor.hpp:40