tsvd_mg.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
#include "tsvd.hpp"

#include <cumlprims/opg/matrix/data.hpp>
#include <cumlprims/opg/matrix/part_descriptor.hpp>

#include <cstdint>
#include <vector>

13 namespace ML {
14 namespace TSVD {
15 namespace opg {
16 
30 void fit(raft::handle_t& handle,
31  MLCommon::Matrix::RankSizePair** rank_sizes,
32  std::uint32_t n_parts,
33  MLCommon::Matrix::floatData_t** input,
34  float* components,
35  float* singular_vals,
36  paramsTSVDMG& prms,
37  bool verbose = false,
38  bool flip_signs_based_on_U = false);
39 
40 void fit(raft::handle_t& handle,
41  MLCommon::Matrix::RankSizePair** rank_sizes,
42  std::uint32_t n_parts,
43  MLCommon::Matrix::doubleData_t** input,
44  double* components,
45  double* singular_vals,
46  paramsTSVDMG& prms,
47  bool verbose = false,
48  bool flip_signs_based_on_U = false);
49 
66 void fit_transform(raft::handle_t& handle,
67  std::vector<MLCommon::Matrix::Data<float>*>& input_data,
68  MLCommon::Matrix::PartDescriptor& input_desc,
69  std::vector<MLCommon::Matrix::Data<float>*>& trans_data,
70  MLCommon::Matrix::PartDescriptor& trans_desc,
71  float* components,
72  float* explained_var,
73  float* explained_var_ratio,
74  float* singular_vals,
75  paramsTSVDMG& prms,
76  bool verbose,
77  bool flip_signs_based_on_U);
78 
79 void fit_transform(raft::handle_t& handle,
80  std::vector<MLCommon::Matrix::Data<double>*>& input_data,
81  MLCommon::Matrix::PartDescriptor& input_desc,
82  std::vector<MLCommon::Matrix::Data<double>*>& trans_data,
83  MLCommon::Matrix::PartDescriptor& trans_desc,
84  double* components,
85  double* explained_var,
86  double* explained_var_ratio,
87  double* singular_vals,
88  paramsTSVDMG& prms,
89  bool verbose,
90  bool flip_signs_based_on_U);
91 
103 void transform(raft::handle_t& handle,
104  MLCommon::Matrix::RankSizePair** rank_sizes,
105  std::uint32_t n_parts,
106  MLCommon::Matrix::Data<float>** input,
107  float* components,
108  MLCommon::Matrix::Data<float>** trans_input,
109  paramsTSVDMG& prms,
110  bool verbose);
111 
112 void transform(raft::handle_t& handle,
113  MLCommon::Matrix::RankSizePair** rank_sizes,
114  std::uint32_t n_parts,
115  MLCommon::Matrix::Data<double>** input,
116  double* components,
117  MLCommon::Matrix::Data<double>** trans_input,
118  paramsTSVDMG& prms,
119  bool verbose);
120 
132 void inverse_transform(raft::handle_t& handle,
133  MLCommon::Matrix::RankSizePair** rank_sizes,
134  std::uint32_t n_parts,
135  MLCommon::Matrix::Data<float>** trans_input,
136  float* components,
137  MLCommon::Matrix::Data<float>** input,
138  paramsTSVDMG& prms,
139  bool verbose);
140 
141 void inverse_transform(raft::handle_t& handle,
142  MLCommon::Matrix::RankSizePair** rank_sizes,
143  std::uint32_t n_parts,
144  MLCommon::Matrix::Data<double>** trans_input,
145  double* components,
146  MLCommon::Matrix::Data<double>** input,
147  paramsTSVDMG& prms,
148  bool verbose);
149 
150 }; // end namespace opg
151 }; // namespace TSVD
152 }; // end namespace ML
Definition: params.hpp:39
void fit(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, std::uint32_t n_parts, MLCommon::Matrix::floatData_t **input, float *components, float *singular_vals, paramsTSVDMG &prms, bool verbose=false, bool flip_signs_based_on_U=false)
Performs the MNMG fit operation for the TSVD.
void inverse_transform(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, std::uint32_t n_parts, MLCommon::Matrix::Data< float > **trans_input, float *components, MLCommon::Matrix::Data< float > **input, paramsTSVDMG &prms, bool verbose)
Performs the MNMG inverse transform operation for the TSVD.
void fit_transform(raft::handle_t &handle, std::vector< MLCommon::Matrix::Data< float > * > &input_data, MLCommon::Matrix::PartDescriptor &input_desc, std::vector< MLCommon::Matrix::Data< float > * > &trans_data, MLCommon::Matrix::PartDescriptor &trans_desc, float *components, float *explained_var, float *explained_var_ratio, float *singular_vals, paramsTSVDMG &prms, bool verbose, bool flip_signs_based_on_U)
performs MNMG fit and transform operation for the tsvd.
void transform(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, std::uint32_t n_parts, MLCommon::Matrix::Data< float > **input, float *components, MLCommon::Matrix::Data< float > **trans_input, paramsTSVDMG &prms, bool verbose)
performs MNMG transform operation for the tsvd.
Definition: dbscan.hpp:18