tsvd_mg.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2020-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include "tsvd.hpp"
20 
21 #include <cumlprims/opg/matrix/data.hpp>
22 #include <cumlprims/opg/matrix/part_descriptor.hpp>
23 
24 namespace ML {
25 namespace TSVD {
26 namespace opg {
27 
39 void fit(raft::handle_t& handle,
40  MLCommon::Matrix::RankSizePair** rank_sizes,
41  std::uint32_t n_parts,
42  MLCommon::Matrix::floatData_t** input,
43  float* components,
44  float* singular_vals,
45  paramsTSVDMG& prms,
46  bool verbose = false);
47 
48 void fit(raft::handle_t& handle,
49  MLCommon::Matrix::RankSizePair** rank_sizes,
50  std::uint32_t n_parts,
51  MLCommon::Matrix::doubleData_t** input,
52  double* components,
53  double* singular_vals,
54  paramsTSVDMG& prms,
55  bool verbose = false);
56 
71 void fit_transform(raft::handle_t& handle,
72  std::vector<MLCommon::Matrix::Data<float>*>& input_data,
73  MLCommon::Matrix::PartDescriptor& input_desc,
74  std::vector<MLCommon::Matrix::Data<float>*>& trans_data,
75  MLCommon::Matrix::PartDescriptor& trans_desc,
76  float* components,
77  float* explained_var,
78  float* explained_var_ratio,
79  float* singular_vals,
80  paramsTSVDMG& prms,
81  bool verbose);
82 
83 void fit_transform(raft::handle_t& handle,
84  std::vector<MLCommon::Matrix::Data<double>*>& input_data,
85  MLCommon::Matrix::PartDescriptor& input_desc,
86  std::vector<MLCommon::Matrix::Data<double>*>& trans_data,
87  MLCommon::Matrix::PartDescriptor& trans_desc,
88  double* components,
89  double* explained_var,
90  double* explained_var_ratio,
91  double* singular_vals,
92  paramsTSVDMG& prms,
93  bool verbose);
94 
106 void transform(raft::handle_t& handle,
107  MLCommon::Matrix::RankSizePair** rank_sizes,
108  std::uint32_t n_parts,
109  MLCommon::Matrix::Data<float>** input,
110  float* components,
111  MLCommon::Matrix::Data<float>** trans_input,
112  paramsTSVDMG& prms,
113  bool verbose);
114 
115 void transform(raft::handle_t& handle,
116  MLCommon::Matrix::RankSizePair** rank_sizes,
117  std::uint32_t n_parts,
118  MLCommon::Matrix::Data<double>** input,
119  double* components,
120  MLCommon::Matrix::Data<double>** trans_input,
121  paramsTSVDMG& prms,
122  bool verbose);
123 
135 void inverse_transform(raft::handle_t& handle,
136  MLCommon::Matrix::RankSizePair** rank_sizes,
137  std::uint32_t n_parts,
138  MLCommon::Matrix::Data<float>** trans_input,
139  float* components,
140  MLCommon::Matrix::Data<float>** input,
141  paramsTSVDMG& prms,
142  bool verbose);
143 
144 void inverse_transform(raft::handle_t& handle,
145  MLCommon::Matrix::RankSizePair** rank_sizes,
146  std::uint32_t n_parts,
147  MLCommon::Matrix::Data<double>** trans_input,
148  double* components,
149  MLCommon::Matrix::Data<double>** input,
150  paramsTSVDMG& prms,
151  bool verbose);
152 
153 }; // end namespace opg
154 }; // namespace TSVD
155 }; // end namespace ML
Definition: params.hpp:50
void fit(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, std::uint32_t n_parts, MLCommon::Matrix::floatData_t **input, float *components, float *singular_vals, paramsTSVDMG &prms, bool verbose=false)
Performs the MNMG fit operation for TSVD.
void inverse_transform(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, std::uint32_t n_parts, MLCommon::Matrix::Data< float > **trans_input, float *components, MLCommon::Matrix::Data< float > **input, paramsTSVDMG &prms, bool verbose)
Performs the MNMG inverse transform operation on the transformed TSVD output.
void transform(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, std::uint32_t n_parts, MLCommon::Matrix::Data< float > **input, float *components, MLCommon::Matrix::Data< float > **trans_input, paramsTSVDMG &prms, bool verbose)
Performs the MNMG transform operation for TSVD.
void fit_transform(raft::handle_t &handle, std::vector< MLCommon::Matrix::Data< float > * > &input_data, MLCommon::Matrix::PartDescriptor &input_desc, std::vector< MLCommon::Matrix::Data< float > * > &trans_data, MLCommon::Matrix::PartDescriptor &trans_desc, float *components, float *explained_var, float *explained_var_ratio, float *singular_vals, paramsTSVDMG &prms, bool verbose)
Performs the MNMG fit and transform operations for TSVD.
Definition: dbscan.hpp:30