10 #include <cumlprims/opg/matrix/data.hpp>
11 #include <cumlprims/opg/matrix/part_descriptor.hpp>
// Performs the MNMG fit operation for the tsvd (float overload: the input is
// an array of MLCommon::Matrix::floatData_t pointers).
// NOTE(review): this listing skips the parameters between `input` and
// `flip_signs_based_on_U`; the full signature shown further down in this
// listing also takes `float* components`, `float* singular_vals`,
// `paramsTSVDMG& prms` and `bool verbose = false`.
// `flip_signs_based_on_U` defaults to false.
30 void fit(raft::handle_t& handle,
31 MLCommon::Matrix::RankSizePair** rank_sizes,
32 std::uint32_t n_parts,
33 MLCommon::Matrix::floatData_t** input,
38 bool flip_signs_based_on_U =
false);
// Overload of fit that takes MLCommon::Matrix::doubleData_t input and a
// `double* singular_vals` argument; presumably the same MNMG tsvd fit as the
// float overload, for double data -- TODO confirm against the implementation.
// NOTE(review): the embedded line numbers jump (43 -> 45 -> 48), so some
// parameters of this declaration are not visible in this listing.
// `flip_signs_based_on_U` defaults to false.
40 void fit(raft::handle_t& handle,
41 MLCommon::Matrix::RankSizePair** rank_sizes,
42 std::uint32_t n_parts,
43 MLCommon::Matrix::doubleData_t** input,
45 double* singular_vals,
48 bool flip_signs_based_on_U =
false);
67 std::vector<MLCommon::Matrix::Data<float>*>& input_data,
68 MLCommon::Matrix::PartDescriptor& input_desc,
69 std::vector<MLCommon::Matrix::Data<float>*>& trans_data,
70 MLCommon::Matrix::PartDescriptor& trans_desc,
73 float* explained_var_ratio,
77 bool flip_signs_based_on_U);
80 std::vector<MLCommon::Matrix::Data<double>*>& input_data,
81 MLCommon::Matrix::PartDescriptor& input_desc,
82 std::vector<MLCommon::Matrix::Data<double>*>& trans_data,
83 MLCommon::Matrix::PartDescriptor& trans_desc,
85 double* explained_var,
86 double* explained_var_ratio,
87 double* singular_vals,
90 bool flip_signs_based_on_U);
104 MLCommon::Matrix::RankSizePair** rank_sizes,
105 std::uint32_t n_parts,
106 MLCommon::Matrix::Data<float>** input,
108 MLCommon::Matrix::Data<float>** trans_input,
113 MLCommon::Matrix::RankSizePair** rank_sizes,
114 std::uint32_t n_parts,
115 MLCommon::Matrix::Data<double>** input,
117 MLCommon::Matrix::Data<double>** trans_input,
133 MLCommon::Matrix::RankSizePair** rank_sizes,
134 std::uint32_t n_parts,
135 MLCommon::Matrix::Data<float>** trans_input,
137 MLCommon::Matrix::Data<float>** input,
142 MLCommon::Matrix::RankSizePair** rank_sizes,
143 std::uint32_t n_parts,
144 MLCommon::Matrix::Data<double>** trans_input,
146 MLCommon::Matrix::Data<double>** input,
Definition: params.hpp:39
void fit(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, std::uint32_t n_parts, MLCommon::Matrix::floatData_t **input, float *components, float *singular_vals, paramsTSVDMG &prms, bool verbose=false, bool flip_signs_based_on_U=false)
Performs the MNMG fit operation for the tsvd.
void inverse_transform(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, std::uint32_t n_parts, MLCommon::Matrix::Data< float > **trans_input, float *components, MLCommon::Matrix::Data< float > **input, paramsTSVDMG &prms, bool verbose)
Performs the MNMG inverse transform operation for the output.
void fit_transform(raft::handle_t &handle, std::vector< MLCommon::Matrix::Data< float > * > &input_data, MLCommon::Matrix::PartDescriptor &input_desc, std::vector< MLCommon::Matrix::Data< float > * > &trans_data, MLCommon::Matrix::PartDescriptor &trans_desc, float *components, float *explained_var, float *explained_var_ratio, float *singular_vals, paramsTSVDMG &prms, bool verbose, bool flip_signs_based_on_U)
Performs the MNMG fit and transform operation for the tsvd.
void transform(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, std::uint32_t n_parts, MLCommon::Matrix::Data< float > **input, float *components, MLCommon::Matrix::Data< float > **trans_input, paramsTSVDMG &prms, bool verbose)
Performs the MNMG transform operation for the tsvd.
Definition: dbscan.hpp:18