qn_mg.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #include <cuml/common/logger.hpp>
7 #include <cuml/linear_model/qn.h>
8 
9 #include <cumlprims/opg/matrix/data.hpp>
10 #include <cumlprims/opg/matrix/part_descriptor.hpp>
11 #include <raft/core/comms.hpp>
12 
13 #include <cuda_runtime.h>
14 
15 #include <vector>
16 using namespace MLCommon;
17 
18 namespace ML {
19 namespace GLM {
20 namespace opg {
21 
29 template <typename T>
30 std::vector<T> getUniquelabelsMG(const raft::handle_t& handle,
31  Matrix::PartDescriptor& input_desc,
32  std::vector<Matrix::Data<T>*>& labels);
33 
48 template <typename T>
49 void qnFit(raft::handle_t& handle,
50  std::vector<Matrix::Data<T>*>& input_data,
51  Matrix::PartDescriptor& input_desc,
52  std::vector<Matrix::Data<T>*>& labels,
53  T* coef,
54  const qn_params& pams,
55  bool X_col_major,
56  bool standardization,
57  int n_classes,
58  T* f,
59  int* num_iters);
60 
80 template <typename T, typename I>
81 void qnFitSparse(raft::handle_t& handle,
82  std::vector<Matrix::Data<T>*>& input_values,
83  I* input_cols,
84  I* input_row_ids,
85  I X_nnz,
86  Matrix::PartDescriptor& input_desc,
87  std::vector<Matrix::Data<T>*>& labels,
88  T* coef,
89  const qn_params& pams,
90  bool standardization,
91  int n_classes,
92  T* f,
93  int* num_iters);
94 
95 }; // namespace opg
96 }; // namespace GLM
97 }; // namespace ML
Definition: Timer.h:9
std::vector< T > getUniquelabelsMG(const raft::handle_t &handle, Matrix::PartDescriptor &input_desc, std::vector< Matrix::Data< T > * > &labels)
Calculate unique class labels across multiple GPUs in a multi-node environment.
void qnFitSparse(raft::handle_t &handle, std::vector< Matrix::Data< T > * > &input_values, I *input_cols, I *input_row_ids, I X_nnz, Matrix::PartDescriptor &input_desc, std::vector< Matrix::Data< T > * > &labels, T *coef, const qn_params &pams, bool standardization, int n_classes, T *f, int *num_iters)
Supports sparse input (Compressed Sparse Row format) for MNMG logistic regression fit using quasi-Newton methods.
void qnFit(raft::handle_t &handle, std::vector< Matrix::Data< T > * > &input_data, Matrix::PartDescriptor &input_desc, std::vector< Matrix::Data< T > * > &labels, T *coef, const qn_params &pams, bool X_col_major, bool standardization, int n_classes, T *f, int *num_iters)
Performs the MNMG fit operation for logistic regression using quasi-Newton methods.
Definition: dbscan.hpp:18
Definition: qn.h:56