ols_mg.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
11 
12 namespace ML {
13 namespace OLS {
14 namespace opg {
15 
28 void fit(raft::handle_t& handle,
29  std::vector<MLCommon::Matrix::Data<float>*>& input_data,
31  std::vector<MLCommon::Matrix::Data<float>*>& labels,
32  float* coef,
33  float* intercept,
34  bool fit_intercept,
35  int algo,
36  bool verbose);
37 
38 void fit(raft::handle_t& handle,
39  std::vector<MLCommon::Matrix::Data<double>*>& input_data,
41  std::vector<MLCommon::Matrix::Data<double>*>& labels,
42  double* coef,
43  double* intercept,
44  bool fit_intercept,
45  int algo,
46  bool verbose);
47 
61 void predict(raft::handle_t& handle,
62  MLCommon::Matrix::RankSizePair** rank_sizes,
63  size_t n_parts,
65  size_t n_rows,
66  size_t n_cols,
67  float* coef,
68  float intercept,
70  bool verbose);
71 
72 void predict(raft::handle_t& handle,
73  MLCommon::Matrix::RankSizePair** rank_sizes,
74  size_t n_parts,
76  size_t n_rows,
77  size_t n_cols,
78  double* coef,
79  double intercept,
81  bool verbose);
82 
83 }; // end namespace opg
84 }; // end namespace OLS
85 }; // end namespace ML
void fit(raft::handle_t &handle, std::vector< MLCommon::Matrix::Data< float > * > &input_data, MLCommon::Matrix::PartDescriptor &input_desc, std::vector< MLCommon::Matrix::Data< float > * > &labels, float *coef, float *intercept, bool fit_intercept, int algo, bool verbose)
performs MNMG fit operation for ordinary least squares (OLS) regression
void predict(raft::handle_t &handle, MLCommon::Matrix::RankSizePair **rank_sizes, size_t n_parts, MLCommon::Matrix::Data< float > **input, size_t n_rows, size_t n_cols, float *coef, float intercept, MLCommon::Matrix::Data< float > **preds, bool verbose)
performs MNMG prediction for OLS
Definition: dbscan.hpp:18
This is a helper wrapper around the multi-gpu data blocks owned by a worker. Its design is NOT final...
Definition: data.hpp:18
Definition: part_descriptor.hpp:40
Definition: part_descriptor.hpp:27