ridge_mg.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
8 #include "glm.hpp"
9 
12 
13 namespace ML {
14 namespace Ridge {
15 namespace opg {
16 
31 void fit(raft::handle_t& handle,
32  std::vector<MLCommon::Matrix::Data<float>*>& input_data,
34  std::vector<MLCommon::Matrix::Data<float>*>& labels,
35  float* alpha,
36  int n_alpha,
37  float* coef,
38  float* intercept,
39  bool fit_intercept,
40  int algo,
41  bool verbose);
42 
43 void fit(raft::handle_t& handle,
44  std::vector<MLCommon::Matrix::Data<double>*>& input_data,
46  std::vector<MLCommon::Matrix::Data<double>*>& labels,
47  double* alpha,
48  int n_alpha,
49  double* coef,
50  double* intercept,
51  bool fit_intercept,
52  int algo,
53  bool verbose);
54 
68 void predict(raft::handle_t& handle,
69  MLCommon::Matrix::RankSizePair** rank_sizes,
70  size_t n_parts,
72  size_t n_rows,
73  size_t n_cols,
74  float* coef,
75  float intercept,
77  bool verbose);
78 
79 void predict(raft::handle_t& handle,
80  MLCommon::Matrix::RankSizePair** rank_sizes,
81  size_t n_parts,
83  size_t n_rows,
84  size_t n_cols,
85  double* coef,
86  double intercept,
88  bool verbose);
89 
90 }; // end namespace opg
91 }; // end namespace Ridge
92 }; // end namespace ML
void predict(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, size_t n_parts, MLCommon::Matrix::Data< float > **input, size_t n_rows, size_t n_cols, float *coef, float intercept, MLCommon::Matrix::Data< float > **preds, bool verbose)
performs MNMG prediction for OLS
void fit(raft::handle_t &handle, std::vector< MLCommon::Matrix::Data< float > * > &input_data, MLCommon::Matrix::PartDescriptor &input_desc, std::vector< MLCommon::Matrix::Data< float > * > &labels, float *alpha, int n_alpha, float *coef, float *intercept, bool fit_intercept, int algo, bool verbose)
performs MNMG fit operation for the ridge regression
Definition: dbscan.hpp:18
This is a helper wrapper around the multi-GPU data blocks owned by a worker. Its design is NOT final...
Definition: data.hpp:18
Definition: part_descriptor.hpp:40
Definition: part_descriptor.hpp:27