cd_mg.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
11 
12 namespace ML {
13 namespace CD {
14 namespace opg {
15 
33 int fit(raft::handle_t& handle,
34  std::vector<MLCommon::Matrix::Data<float>*>& input_data,
36  std::vector<MLCommon::Matrix::Data<float>*>& labels,
37  float* coef,
38  float* intercept,
39  bool fit_intercept,
40  int epochs,
41  float alpha,
42  float l1_ratio,
43  bool shuffle,
44  float tol,
45  bool verbose);
46 
47 int fit(raft::handle_t& handle,
48  std::vector<MLCommon::Matrix::Data<double>*>& input_data,
50  std::vector<MLCommon::Matrix::Data<double>*>& labels,
51  double* coef,
52  double* intercept,
53  bool fit_intercept,
54  int epochs,
55  double alpha,
56  double l1_ratio,
57  bool shuffle,
58  double tol,
59  bool verbose);
60 
74 void predict(raft::handle_t& handle,
75  MLCommon::Matrix::RankSizePair** rank_sizes,
76  size_t n_parts,
78  size_t n_rows,
79  size_t n_cols,
80  float* coef,
81  float intercept,
83  bool verbose);
84 
85 void predict(raft::handle_t& handle,
86  MLCommon::Matrix::RankSizePair** rank_sizes,
87  size_t n_parts,
89  size_t n_rows,
90  size_t n_cols,
91  double* coef,
92  double intercept,
94  bool verbose);
95 
96 }; // end namespace opg
97 }; // namespace CD
98 }; // end namespace ML
int fit(raft::handle_t &handle, std::vector< MLCommon::Matrix::Data< float > * > &input_data, MLCommon::Matrix::PartDescriptor &input_desc, std::vector< MLCommon::Matrix::Data< float > * > &labels, float *coef, float *intercept, bool fit_intercept, int epochs, float alpha, float l1_ratio, bool shuffle, float tol, bool verbose)
performs the MNMG (multi-node, multi-GPU) fit operation for the coordinate-descent (CD) regularized linear regression (alpha / l1_ratio penalty)
void predict(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, size_t n_parts, MLCommon::Matrix::Data< float > **input, size_t n_rows, size_t n_cols, float *coef, float intercept, MLCommon::Matrix::Data< float > **preds, bool verbose)
performs MNMG (multi-node, multi-GPU) prediction for the coordinate-descent (CD) linear model
void shuffle(std::vector< math_t > &rand_indices, std::mt19937 &g)
Definition: shuffle.h:24
Definition: dbscan.hpp:18
This is a helper wrapper around the multi-GPU data blocks owned by a worker. Its design is NOT final...
Definition: data.hpp:18
Definition: part_descriptor.hpp:40
Definition: part_descriptor.hpp:27