qn_mg.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #include <cuml/common/logger.hpp>
7 #include <cuml/linear_model/qn.h>
10 
11 #include <raft/core/comms.hpp>
12 
13 #include <cuda_runtime.h>
14 
15 #include <vector>
16 
17 namespace ML {
18 namespace GLM {
19 namespace opg {
20 
28 template <typename T>
29 std::vector<T> getUniquelabelsMG(const raft::handle_t& handle,
31  std::vector<MLCommon::Matrix::Data<T>*>& labels);
32 
47 template <typename T>
48 void qnFit(raft::handle_t& handle,
49  std::vector<MLCommon::Matrix::Data<T>*>& input_data,
51  std::vector<MLCommon::Matrix::Data<T>*>& labels,
52  T* coef,
53  const qn_params& pams,
54  bool X_col_major,
55  bool standardization,
56  int n_classes,
57  T* f,
58  int* num_iters);
59 
79 template <typename T, typename I>
80 void qnFitSparse(raft::handle_t& handle,
81  std::vector<MLCommon::Matrix::Data<T>*>& input_values,
82  I* input_cols,
83  I* input_row_ids,
84  I X_nnz,
86  std::vector<MLCommon::Matrix::Data<T>*>& labels,
87  T* coef,
88  const qn_params& pams,
89  bool standardization,
90  int n_classes,
91  T* f,
92  int* num_iters);
93 
94 }; // namespace opg
95 }; // namespace GLM
96 }; // namespace ML
void qnFit(raft::handle_t &handle, std::vector< MLCommon::Matrix::Data< T > * > &input_data, MLCommon::Matrix::PartDescriptor &input_desc, std::vector< MLCommon::Matrix::Data< T > * > &labels, T *coef, const qn_params &pams, bool X_col_major, bool standardization, int n_classes, T *f, int *num_iters)
Performs the MNMG (multi-node, multi-GPU) fit of logistic regression using quasi-Newton methods.
void qnFitSparse(raft::handle_t &handle, std::vector< MLCommon::Matrix::Data< T > * > &input_values, I *input_cols, I *input_row_ids, I X_nnz, MLCommon::Matrix::PartDescriptor &input_desc, std::vector< MLCommon::Matrix::Data< T > * > &labels, T *coef, const qn_params &pams, bool standardization, int n_classes, T *f, int *num_iters)
Supports sparse vectors (Compressed Sparse Row format) in the MNMG logistic regression fit using quasi-Newton methods.
std::vector< T > getUniquelabelsMG(const raft::handle_t &handle, MLCommon::Matrix::PartDescriptor &input_desc, std::vector< MLCommon::Matrix::Data< T > * > &labels)
Calculate unique class labels across multiple GPUs in a multi-node environment.
Definition: dbscan.hpp:18
This is a helper wrapper around the multi-GPU data blocks owned by a worker. Its design is NOT final...
Definition: data.hpp:18
Definition: part_descriptor.hpp:40
Definition: qn.h:56