glm.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2018-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
7 #include <cuml/linear_model/qn.h>
8 
9 #include <raft/core/handle.hpp>
10 
11 namespace ML {
12 namespace GLM {
13 
/**
 * Declaration of the single-precision (float) overload of olsFit. Only the
 * prototype is visible in this header; the definition lives elsewhere.
 *
 * NOTE(review): the name suggests an ordinary-least-squares fit, but nothing
 * in this listing shows the implementation -- confirm against the definition.
 * NOTE(review): the embedded listing numbers jump from 13 to 29 here, so the
 * original doc comment appears to be missing from this rendering.
 *
 * @param handle        raft handle, taken by const reference
 * @param input         non-const pointer to the input data (layout and memory
 *                      space are not shown here -- TODO confirm)
 * @param n_rows        row count, presumably of `input` -- TODO confirm
 * @param n_cols        column count, presumably of `input` -- TODO confirm
 * @param labels        non-const pointer to the labels
 * @param coef          non-const pointer; presumably receives the fitted
 *                      coefficients -- TODO confirm
 * @param intercept     non-const pointer; presumably receives the fitted
 *                      intercept -- TODO confirm
 * @param fit_intercept boolean flag; presumably enables intercept fitting
 * @param algo          integer selector, defaults to 0; the meaning of each
 *                      value is not shown in this header
 * @param sample_weight optional pointer, defaults to nullptr
 */
29 void olsFit(const raft::handle_t& handle,
30  float* input,
31  size_t n_rows,
32  size_t n_cols,
33  float* labels,
34  float* coef,
35  float* intercept,
36  bool fit_intercept,
37  int algo = 0,
38  float* sample_weight = nullptr);
/**
 * Declaration of the double-precision overload of olsFit. The parameter list
 * mirrors the float overload with every `float` replaced by `double`.
 *
 * NOTE(review): the name suggests an ordinary-least-squares fit, but the
 * implementation is not visible here -- confirm against the definition.
 *
 * @param handle        raft handle, taken by const reference
 * @param input         non-const pointer to the input data (layout and memory
 *                      space are not shown here -- TODO confirm)
 * @param n_rows        row count, presumably of `input` -- TODO confirm
 * @param n_cols        column count, presumably of `input` -- TODO confirm
 * @param labels        non-const pointer to the labels
 * @param coef          non-const pointer; presumably receives the fitted
 *                      coefficients -- TODO confirm
 * @param intercept     non-const pointer; presumably receives the fitted
 *                      intercept -- TODO confirm
 * @param fit_intercept boolean flag; presumably enables intercept fitting
 * @param algo          integer selector, defaults to 0; the meaning of each
 *                      value is not shown in this header
 * @param sample_weight optional pointer, defaults to nullptr
 */
39 void olsFit(const raft::handle_t& handle,
40  double* input,
41  size_t n_rows,
42  size_t n_cols,
43  double* labels,
44  double* coef,
45  double* intercept,
46  bool fit_intercept,
47  int algo = 0,
48  double* sample_weight = nullptr);
/**
 * Declaration of the single-precision (float) overload of ridgeFit. Only the
 * prototype is visible in this header.
 *
 * Compared with olsFit, the signature adds an `alpha` pointer and an
 * `n_alpha` count ahead of `coef`.
 *
 * NOTE(review): the name suggests a ridge (L2-regularized) regression fit;
 * the implementation is not visible here -- confirm against the definition.
 * NOTE(review): the embedded listing numbers jump from 48 to 67 here, so the
 * original doc comment appears to be missing from this rendering.
 *
 * @param handle        raft handle, taken by const reference
 * @param input         non-const pointer to the input data (layout and memory
 *                      space are not shown here -- TODO confirm)
 * @param n_rows        row count, presumably of `input` -- TODO confirm
 * @param n_cols        column count, presumably of `input` -- TODO confirm
 * @param labels        non-const pointer to the labels
 * @param alpha         non-const pointer; presumably the regularization
 *                      strength value(s) -- TODO confirm
 * @param n_alpha       int count, presumably the number of entries reachable
 *                      through `alpha` -- TODO confirm
 * @param coef          non-const pointer; presumably receives the fitted
 *                      coefficients -- TODO confirm
 * @param intercept     non-const pointer; presumably receives the fitted
 *                      intercept -- TODO confirm
 * @param fit_intercept boolean flag; presumably enables intercept fitting
 * @param algo          integer selector, defaults to 0; the meaning of each
 *                      value is not shown in this header
 * @param sample_weight optional pointer, defaults to nullptr
 */
67 void ridgeFit(const raft::handle_t& handle,
68  float* input,
69  size_t n_rows,
70  size_t n_cols,
71  float* labels,
72  float* alpha,
73  int n_alpha,
74  float* coef,
75  float* intercept,
76  bool fit_intercept,
77  int algo = 0,
78  float* sample_weight = nullptr);
/**
 * Declaration of the double-precision overload of ridgeFit. The parameter
 * list mirrors the float overload with every `float` replaced by `double`.
 *
 * NOTE(review): the name suggests a ridge (L2-regularized) regression fit;
 * the implementation is not visible here -- confirm against the definition.
 *
 * @param handle        raft handle, taken by const reference
 * @param input         non-const pointer to the input data (layout and memory
 *                      space are not shown here -- TODO confirm)
 * @param n_rows        row count, presumably of `input` -- TODO confirm
 * @param n_cols        column count, presumably of `input` -- TODO confirm
 * @param labels        non-const pointer to the labels
 * @param alpha         non-const pointer; presumably the regularization
 *                      strength value(s) -- TODO confirm
 * @param n_alpha       int count, presumably the number of entries reachable
 *                      through `alpha` -- TODO confirm
 * @param coef          non-const pointer; presumably receives the fitted
 *                      coefficients -- TODO confirm
 * @param intercept     non-const pointer; presumably receives the fitted
 *                      intercept -- TODO confirm
 * @param fit_intercept boolean flag; presumably enables intercept fitting
 * @param algo          integer selector, defaults to 0; the meaning of each
 *                      value is not shown in this header
 * @param sample_weight optional pointer, defaults to nullptr
 */
79 void ridgeFit(const raft::handle_t& handle,
80  double* input,
81  size_t n_rows,
82  size_t n_cols,
83  double* labels,
84  double* alpha,
85  int n_alpha,
86  double* coef,
87  double* intercept,
88  bool fit_intercept,
89  int algo = 0,
90  double* sample_weight = nullptr);
/**
 * Declaration of the single-precision (float) overload of gemmPredict. Only
 * the prototype is visible in this header.
 *
 * Unlike the fit functions, `input` and `coef` are pointers to const and
 * `intercept` is passed by value; `preds` is the only non-const pointer.
 *
 * NOTE(review): the name suggests a GEMM-based linear prediction; the
 * implementation is not visible here -- confirm against the definition.
 * NOTE(review): the embedded listing numbers jump from 90 to 104 here, so
 * the original doc comment appears to be missing from this rendering.
 *
 * @param handle    raft handle, taken by const reference
 * @param input     pointer to const input data (layout and memory space are
 *                  not shown here -- TODO confirm)
 * @param n_rows    row count, presumably of `input` -- TODO confirm
 * @param n_cols    column count, presumably of `input` -- TODO confirm
 * @param coef      pointer to const coefficients
 * @param intercept intercept value, passed by value
 * @param preds     non-const pointer; presumably receives the predictions
 *                  -- TODO confirm
 */
104 void gemmPredict(const raft::handle_t& handle,
105  const float* input,
106  size_t n_rows,
107  size_t n_cols,
108  const float* coef,
109  float intercept,
110  float* preds);
/**
 * Declaration of the double-precision overload of gemmPredict. The parameter
 * list mirrors the float overload with every `float` replaced by `double`.
 *
 * NOTE(review): the name suggests a GEMM-based linear prediction; the
 * implementation is not visible here -- confirm against the definition.
 *
 * @param handle    raft handle, taken by const reference
 * @param input     pointer to const input data (layout and memory space are
 *                  not shown here -- TODO confirm)
 * @param n_rows    row count, presumably of `input` -- TODO confirm
 * @param n_cols    column count, presumably of `input` -- TODO confirm
 * @param coef      pointer to const coefficients
 * @param intercept intercept value, passed by value
 * @param preds     non-const pointer; presumably receives the predictions
 *                  -- TODO confirm
 */
111 void gemmPredict(const raft::handle_t& handle,
112  const double* input,
113  size_t n_rows,
114  size_t n_cols,
115  const double* coef,
116  double intercept,
117  double* preds);
/**
 * @brief Fit a GLM using quasi newton methods.
 *
 * Function template declaration; `T` is the value type and `I` the index
 * type, which defaults to `int`. No definition is visible in this header.
 *
 * NOTE(review): the embedded listing numbers jump from 117 to 139 here, so
 * the original doc comment appears to be missing from this rendering.
 *
 * @param cuml_handle   raft handle, taken by const reference
 * @param params        qn_params settings object, taken by const reference
 * @param X             non-const pointer to the dense input data
 * @param X_col_major   boolean flag; presumably true when `X` is stored
 *                      column-major -- TODO confirm
 * @param y             non-const pointer to the targets
 * @param N             presumably the number of samples -- TODO confirm
 * @param D             presumably the number of features -- TODO confirm
 * @param C             presumably the number of classes/outputs -- TODO confirm
 * @param w0            non-const pointer; presumably the initial weights,
 *                      possibly overwritten with the result -- TODO confirm
 * @param f             non-const pointer; presumably receives the final
 *                      objective value -- TODO confirm
 * @param num_iters     non-const int pointer; presumably receives the number
 *                      of iterations performed -- TODO confirm
 * @param sample_weight optional pointer, defaults to nullptr
 * @param svr_eps       value of type T, defaults to 0; presumably the epsilon
 *                      used for SVR-style losses -- TODO confirm
 */
139 template <typename T, typename I = int>
140 void qnFit(const raft::handle_t& cuml_handle,
141  const qn_params& params,
142  T* X,
143  bool X_col_major,
144  T* y,
145  I N,
146  I D,
147  I C,
148  T* w0,
149  T* f,
150  int* num_iters,
151  T* sample_weight = nullptr,
152  T svr_eps = 0);
153 
/**
 * @brief Fit a GLM using quasi newton methods.
 *
 * Sparse-input counterpart of qnFit: the dense `X` / `X_col_major` pair is
 * replaced by `X_values`, `X_cols`, `X_row_ids` and `X_nnz`. `T` is the value
 * type and `I` the index type, which defaults to `int`. No definition is
 * visible in this header.
 *
 * NOTE(review): the four X_* arguments look like a CSR representation, but
 * the exact format is not shown here -- confirm against the definition.
 *
 * @param cuml_handle   raft handle, taken by const reference
 * @param params        qn_params settings object, taken by const reference
 * @param X_values      non-const pointer to the stored (non-zero) values
 * @param X_cols        non-const pointer to index data; presumably the column
 *                      index of each stored value -- TODO confirm
 * @param X_row_ids     non-const pointer to index data; presumably row
 *                      offsets -- TODO confirm
 * @param X_nnz         presumably the number of stored values -- TODO confirm
 * @param y             non-const pointer to the targets
 * @param N             presumably the number of samples -- TODO confirm
 * @param D             presumably the number of features -- TODO confirm
 * @param C             presumably the number of classes/outputs -- TODO confirm
 * @param w0            non-const pointer; presumably the initial weights,
 *                      possibly overwritten with the result -- TODO confirm
 * @param f             non-const pointer; presumably receives the final
 *                      objective value -- TODO confirm
 * @param num_iters     non-const int pointer; presumably receives the number
 *                      of iterations performed -- TODO confirm
 * @param sample_weight optional pointer, defaults to nullptr
 * @param svr_eps       value of type T, defaults to 0; presumably the epsilon
 *                      used for SVR-style losses -- TODO confirm
 */
176 template <typename T, typename I = int>
177 void qnFitSparse(const raft::handle_t& cuml_handle,
178  const qn_params& params,
179  T* X_values,
180  I* X_cols,
181  I* X_row_ids,
182  I X_nnz,
183  T* y,
184  I N,
185  I D,
186  I C,
187  T* w0,
188  T* f,
189  int* num_iters,
190  T* sample_weight = nullptr,
191  T svr_eps = 0);
192 
/**
 * @brief Obtain the confidence scores of samples.
 *
 * Function template declaration for dense input; `T` is the value type and
 * `I` the index type, which defaults to `int`. No definition is visible in
 * this header.
 *
 * @param cuml_handle raft handle, taken by const reference
 * @param params      qn_params settings object, taken by const reference
 * @param X           non-const pointer to the dense input data
 * @param X_col_major boolean flag; presumably true when `X` is stored
 *                    column-major -- TODO confirm
 * @param N           presumably the number of samples -- TODO confirm
 * @param D           presumably the number of features -- TODO confirm
 * @param C           presumably the number of classes/outputs -- TODO confirm
 * @param coefs       non-const pointer to the model coefficients
 * @param scores      non-const pointer; presumably receives the confidence
 *                    scores -- TODO confirm
 */
208 template <typename T, typename I = int>
209 void qnDecisionFunction(const raft::handle_t& cuml_handle,
210  const qn_params& params,
211  T* X,
212  bool X_col_major,
213  I N,
214  I D,
215  I C,
216  T* coefs,
217  T* scores);
218 
/**
 * @brief Obtain the confidence scores of samples.
 *
 * Sparse-input counterpart of qnDecisionFunction: the dense `X` /
 * `X_col_major` pair is replaced by `X_values`, `X_cols`, `X_row_ids` and
 * `X_nnz`. `T` is the value type and `I` the index type, which defaults to
 * `int`. No definition is visible in this header.
 *
 * NOTE(review): the four X_* arguments look like a CSR representation, but
 * the exact format is not shown here -- confirm against the definition.
 *
 * @param cuml_handle raft handle, taken by const reference
 * @param params      qn_params settings object, taken by const reference
 * @param X_values    non-const pointer to the stored (non-zero) values
 * @param X_cols      non-const pointer to index data; presumably the column
 *                    index of each stored value -- TODO confirm
 * @param X_row_ids   non-const pointer to index data; presumably row offsets
 *                    -- TODO confirm
 * @param X_nnz       presumably the number of stored values -- TODO confirm
 * @param N           presumably the number of samples -- TODO confirm
 * @param D           presumably the number of features -- TODO confirm
 * @param C           presumably the number of classes/outputs -- TODO confirm
 * @param coefs       non-const pointer to the model coefficients
 * @param scores      non-const pointer; presumably receives the confidence
 *                    scores -- TODO confirm
 */
237 template <typename T, typename I = int>
238 void qnDecisionFunctionSparse(const raft::handle_t& cuml_handle,
239  const qn_params& params,
240  T* X_values,
241  I* X_cols,
242  I* X_row_ids,
243  I X_nnz,
244  I N,
245  I D,
246  I C,
247  T* coefs,
248  T* scores);
249 
/**
 * @brief Predict a GLM using quasi newton methods.
 *
 * Function template declaration for dense input; `T` is the value type and
 * `I` the index type, which defaults to `int`. The parameter list matches
 * qnDecisionFunction except that the final output pointer is named `preds`.
 * No definition is visible in this header.
 *
 * @param cuml_handle raft handle, taken by const reference
 * @param params      qn_params settings object, taken by const reference
 * @param X           non-const pointer to the dense input data
 * @param X_col_major boolean flag; presumably true when `X` is stored
 *                    column-major -- TODO confirm
 * @param N           presumably the number of samples -- TODO confirm
 * @param D           presumably the number of features -- TODO confirm
 * @param C           presumably the number of classes/outputs -- TODO confirm
 * @param coefs       non-const pointer to the model coefficients
 * @param preds       non-const pointer; presumably receives the predictions
 *                    -- TODO confirm
 */
265 template <typename T, typename I = int>
266 void qnPredict(const raft::handle_t& cuml_handle,
267  const qn_params& params,
268  T* X,
269  bool X_col_major,
270  I N,
271  I D,
272  I C,
273  T* coefs,
274  T* preds);
275 
/**
 * @brief Predict a GLM using quasi newton methods.
 *
 * Sparse-input counterpart of qnPredict: the dense `X` / `X_col_major` pair
 * is replaced by `X_values`, `X_cols`, `X_row_ids` and `X_nnz`. `T` is the
 * value type and `I` the index type, which defaults to `int`. No definition
 * is visible in this header.
 *
 * NOTE(review): the four X_* arguments look like a CSR representation, but
 * the exact format is not shown here -- confirm against the definition.
 *
 * @param cuml_handle raft handle, taken by const reference
 * @param params      qn_params settings object, taken by const reference
 * @param X_values    non-const pointer to the stored (non-zero) values
 * @param X_cols      non-const pointer to index data; presumably the column
 *                    index of each stored value -- TODO confirm
 * @param X_row_ids   non-const pointer to index data; presumably row offsets
 *                    -- TODO confirm
 * @param X_nnz       presumably the number of stored values -- TODO confirm
 * @param N           presumably the number of samples -- TODO confirm
 * @param D           presumably the number of features -- TODO confirm
 * @param C           presumably the number of classes/outputs -- TODO confirm
 * @param coefs       non-const pointer to the model coefficients
 * @param preds       non-const pointer; presumably receives the predictions
 *                    -- TODO confirm
 */
294 template <typename T, typename I = int>
295 void qnPredictSparse(const raft::handle_t& cuml_handle,
296  const qn_params& params,
297  T* X_values,
298  I* X_cols,
299  I* X_row_ids,
300  I X_nnz,
301  I N,
302  I D,
303  I C,
304  T* coefs,
305  T* preds);
306 
307 } // namespace GLM
308 } // namespace ML
Definition: params.hpp:23
void gemmPredict(const raft::handle_t &handle, const float *input, size_t n_rows, size_t n_cols, const float *coef, float intercept, float *preds)
void olsFit(const raft::handle_t &handle, float *input, size_t n_rows, size_t n_cols, float *labels, float *coef, float *intercept, bool fit_intercept, int algo=0, float *sample_weight=nullptr)
void ridgeFit(const raft::handle_t &handle, float *input, size_t n_rows, size_t n_cols, float *labels, float *alpha, int n_alpha, float *coef, float *intercept, bool fit_intercept, int algo=0, float *sample_weight=nullptr)
void qnDecisionFunctionSparse(const raft::handle_t &cuml_handle, const qn_params &params, T *X_values, I *X_cols, I *X_row_ids, I X_nnz, I N, I D, I C, T *coefs, T *scores)
Obtain the confidence scores of samples.
void qnDecisionFunction(const raft::handle_t &cuml_handle, const qn_params &params, T *X, bool X_col_major, I N, I D, I C, T *coefs, T *scores)
Obtain the confidence scores of samples.
void qnFit(const raft::handle_t &cuml_handle, const qn_params &params, T *X, bool X_col_major, T *y, I N, I D, I C, T *w0, T *f, int *num_iters, T *sample_weight=nullptr, T svr_eps=0)
Fit a GLM using quasi newton methods.
void qnFitSparse(const raft::handle_t &cuml_handle, const qn_params &params, T *X_values, I *X_cols, I *X_row_ids, I X_nnz, T *y, I N, I D, I C, T *w0, T *f, int *num_iters, T *sample_weight=nullptr, T svr_eps=0)
Fit a GLM using quasi newton methods.
void qnPredict(const raft::handle_t &cuml_handle, const qn_params &params, T *X, bool X_col_major, I N, I D, I C, T *coefs, T *preds)
Predict a GLM using quasi newton methods.
void qnPredictSparse(const raft::handle_t &cuml_handle, const qn_params &params, T *X_values, I *X_cols, I *X_row_ids, I X_nnz, I N, I D, I C, T *coefs, T *preds)
Predict a GLM using quasi newton methods.
Definition: dbscan.hpp:18
Definition: qn.h:56