glm.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2018-2022, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
7 #include <cuml/linear_model/qn.h>
8 
9 #include <raft/core/handle.hpp>
10 
11 namespace ML {
12 namespace GLM {
13 
30 void olsFit(const raft::handle_t& handle,
31  float* input,
32  size_t n_rows,
33  size_t n_cols,
34  float* labels,
35  float* coef,
36  float* intercept,
37  bool fit_intercept,
38  bool normalize,
39  int algo = 0,
40  float* sample_weight = nullptr);
41 void olsFit(const raft::handle_t& handle,
42  double* input,
43  size_t n_rows,
44  size_t n_cols,
45  double* labels,
46  double* coef,
47  double* intercept,
48  bool fit_intercept,
49  bool normalize,
50  int algo = 0,
51  double* sample_weight = nullptr);
71 void ridgeFit(const raft::handle_t& handle,
72  float* input,
73  size_t n_rows,
74  size_t n_cols,
75  float* labels,
76  float* alpha,
77  int n_alpha,
78  float* coef,
79  float* intercept,
80  bool fit_intercept,
81  bool normalize,
82  int algo = 0,
83  float* sample_weight = nullptr);
84 void ridgeFit(const raft::handle_t& handle,
85  double* input,
86  size_t n_rows,
87  size_t n_cols,
88  double* labels,
89  double* alpha,
90  int n_alpha,
91  double* coef,
92  double* intercept,
93  bool fit_intercept,
94  bool normalize,
95  int algo = 0,
96  double* sample_weight = nullptr);
110 void gemmPredict(const raft::handle_t& handle,
111  const float* input,
112  size_t n_rows,
113  size_t n_cols,
114  const float* coef,
115  float intercept,
116  float* preds);
117 void gemmPredict(const raft::handle_t& handle,
118  const double* input,
119  size_t n_rows,
120  size_t n_cols,
121  const double* coef,
122  double intercept,
123  double* preds);
145 template <typename T, typename I = int>
146 void qnFit(const raft::handle_t& cuml_handle,
147  const qn_params& params,
148  T* X,
149  bool X_col_major,
150  T* y,
151  I N,
152  I D,
153  I C,
154  T* w0,
155  T* f,
156  int* num_iters,
157  T* sample_weight = nullptr,
158  T svr_eps = 0);
159 
182 template <typename T, typename I = int>
183 void qnFitSparse(const raft::handle_t& cuml_handle,
184  const qn_params& params,
185  T* X_values,
186  I* X_cols,
187  I* X_row_ids,
188  I X_nnz,
189  T* y,
190  I N,
191  I D,
192  I C,
193  T* w0,
194  T* f,
195  int* num_iters,
196  T* sample_weight = nullptr,
197  T svr_eps = 0);
198 
214 template <typename T, typename I = int>
215 void qnDecisionFunction(const raft::handle_t& cuml_handle,
216  const qn_params& params,
217  T* X,
218  bool X_col_major,
219  I N,
220  I D,
221  I C,
222  T* coefs,
223  T* scores);
224 
243 template <typename T, typename I = int>
244 void qnDecisionFunctionSparse(const raft::handle_t& cuml_handle,
245  const qn_params& params,
246  T* X_values,
247  I* X_cols,
248  I* X_row_ids,
249  I X_nnz,
250  I N,
251  I D,
252  I C,
253  T* coefs,
254  T* scores);
255 
271 template <typename T, typename I = int>
272 void qnPredict(const raft::handle_t& cuml_handle,
273  const qn_params& params,
274  T* X,
275  bool X_col_major,
276  I N,
277  I D,
278  I C,
279  T* coefs,
280  T* preds);
281 
300 template <typename T, typename I = int>
301 void qnPredictSparse(const raft::handle_t& cuml_handle,
302  const qn_params& params,
303  T* X_values,
304  I* X_cols,
305  I* X_row_ids,
306  I X_nnz,
307  I N,
308  I D,
309  I C,
310  T* coefs,
311  T* preds);
312 
313 } // namespace GLM
314 } // namespace ML
Definition: params.hpp:23
void gemmPredict(const raft::handle_t &handle, const float *input, size_t n_rows, size_t n_cols, const float *coef, float intercept, float *preds)
void olsFit(const raft::handle_t &handle, float *input, size_t n_rows, size_t n_cols, float *labels, float *coef, float *intercept, bool fit_intercept, bool normalize, int algo=0, float *sample_weight=nullptr)
void ridgeFit(const raft::handle_t &handle, float *input, size_t n_rows, size_t n_cols, float *labels, float *alpha, int n_alpha, float *coef, float *intercept, bool fit_intercept, bool normalize, int algo=0, float *sample_weight=nullptr)
void qnDecisionFunctionSparse(const raft::handle_t &cuml_handle, const qn_params &params, T *X_values, I *X_cols, I *X_row_ids, I X_nnz, I N, I D, I C, T *coefs, T *scores)
Obtain the confidence scores of samples.
void qnDecisionFunction(const raft::handle_t &cuml_handle, const qn_params &params, T *X, bool X_col_major, I N, I D, I C, T *coefs, T *scores)
Obtain the confidence scores of samples.
void qnFit(const raft::handle_t &cuml_handle, const qn_params &params, T *X, bool X_col_major, T *y, I N, I D, I C, T *w0, T *f, int *num_iters, T *sample_weight=nullptr, T svr_eps=0)
Fit a GLM using quasi-Newton methods.
void qnFitSparse(const raft::handle_t &cuml_handle, const qn_params &params, T *X_values, I *X_cols, I *X_row_ids, I X_nnz, T *y, I N, I D, I C, T *w0, T *f, int *num_iters, T *sample_weight=nullptr, T svr_eps=0)
Fit a GLM using quasi-Newton methods.
void qnPredict(const raft::handle_t &cuml_handle, const qn_params &params, T *X, bool X_col_major, I N, I D, I C, T *coefs, T *preds)
Predict using a GLM fitted with quasi-Newton methods.
void qnPredictSparse(const raft::handle_t &cuml_handle, const qn_params &params, T *X_values, I *X_cols, I *X_row_ids, I X_nnz, I N, I D, I C, T *coefs, T *preds)
Predict using a GLM fitted with quasi-Newton methods.
void normalize(value_t *data, value_idx n, size_t m, cudaStream_t stream)
Definition: utils.h:177
Definition: dbscan.hpp:18
Definition: qn.h:56