glm.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2022, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include <cuml/linear_model/qn.h>
19 
20 #include <raft/core/handle.hpp>
21 
22 namespace ML {
23 namespace GLM {
24 
/**
 * @brief Fit a linear regression model to (input, labels). The name suggests
 *        ordinary least squares (OLS) — confirm against the implementation.
 *
 * Declared for float and double; both overloads take the same arguments.
 *
 * @param handle        RAFT handle providing the execution resources.
 * @param input         Feature matrix of n_rows x n_cols values. Taken as a
 *                      non-const pointer, so the implementation may modify it
 *                      (e.g. when centering/normalizing) — NOTE(review): confirm.
 *                      The memory layout is not stated in this header — confirm.
 * @param n_rows        Number of rows (samples) in input.
 * @param n_cols        Number of columns (features) in input.
 * @param labels        Target values, presumably n_rows entries. Non-const,
 *                      so possibly modified in place — confirm.
 * @param coef          Output coefficients, presumably n_cols entries.
 * @param intercept     Output location for the fitted intercept (presumably a
 *                      single value).
 * @param fit_intercept Whether an intercept term is fitted.
 * @param normalize     Whether the input is normalized before fitting.
 * @param algo          Solver selector, default 0. The meaning of each value
 *                      is not visible in this header — TODO document.
 * @param sample_weight Optional per-sample weights, default nullptr
 *                      (presumably meaning unweighted).
 */
41 void olsFit(const raft::handle_t& handle,
42  float* input,
43  size_t n_rows,
44  size_t n_cols,
45  float* labels,
46  float* coef,
47  float* intercept,
48  bool fit_intercept,
49  bool normalize,
50  int algo = 0,
51  float* sample_weight = nullptr);
52 void olsFit(const raft::handle_t& handle,
53  double* input,
54  size_t n_rows,
55  size_t n_cols,
56  double* labels,
57  double* coef,
58  double* intercept,
59  bool fit_intercept,
60  bool normalize,
61  int algo = 0,
62  double* sample_weight = nullptr);
/**
 * @brief Fit a ridge (L2-regularized) linear regression model, as the name
 *        suggests — confirm against the implementation.
 *
 * Declared for float and double; both overloads take the same arguments.
 *
 * @param handle        RAFT handle providing the execution resources.
 * @param input         Feature matrix of n_rows x n_cols values. Non-const,
 *                      so the implementation may modify it — NOTE(review): confirm.
 *                      The memory layout is not stated in this header — confirm.
 * @param n_rows        Number of rows (samples) in input.
 * @param n_cols        Number of columns (features) in input.
 * @param labels        Target values, presumably n_rows entries. Non-const,
 *                      so possibly modified in place — confirm.
 * @param alpha         Array of n_alpha regularization strength values
 *                      (presumed meaning — confirm).
 * @param n_alpha       Number of entries in alpha.
 * @param coef          Output coefficients, presumably n_cols entries.
 * @param intercept     Output location for the fitted intercept (presumably a
 *                      single value).
 * @param fit_intercept Whether an intercept term is fitted.
 * @param normalize     Whether the input is normalized before fitting.
 * @param algo          Solver selector, default 0. The meaning of each value
 *                      is not visible in this header — TODO document.
 * @param sample_weight Optional per-sample weights, default nullptr
 *                      (presumably meaning unweighted).
 */
82 void ridgeFit(const raft::handle_t& handle,
83  float* input,
84  size_t n_rows,
85  size_t n_cols,
86  float* labels,
87  float* alpha,
88  int n_alpha,
89  float* coef,
90  float* intercept,
91  bool fit_intercept,
92  bool normalize,
93  int algo = 0,
94  float* sample_weight = nullptr);
95 void ridgeFit(const raft::handle_t& handle,
96  double* input,
97  size_t n_rows,
98  size_t n_cols,
99  double* labels,
100  double* alpha,
101  int n_alpha,
102  double* coef,
103  double* intercept,
104  bool fit_intercept,
105  bool normalize,
106  int algo = 0,
107  double* sample_weight = nullptr);
/**
 * @brief Compute linear-model predictions from fitted coefficients and an
 *        intercept. The name suggests a GEMM-based (matrix multiply)
 *        implementation — confirm.
 *
 * Declared for float and double; both overloads take the same arguments.
 * input and coef are const, so this call does not modify them.
 *
 * @param handle    RAFT handle providing the execution resources.
 * @param input     Feature matrix of n_rows x n_cols values (layout not stated
 *                  here — confirm).
 * @param n_rows    Number of rows (samples) in input.
 * @param n_cols    Number of columns (features) in input.
 * @param coef      Model coefficients, presumably n_cols entries.
 * @param intercept Scalar intercept, passed by value.
 * @param preds     Output predictions, presumably n_rows entries.
 */
121 void gemmPredict(const raft::handle_t& handle,
122  const float* input,
123  size_t n_rows,
124  size_t n_cols,
125  const float* coef,
126  float intercept,
127  float* preds);
128 void gemmPredict(const raft::handle_t& handle,
129  const double* input,
130  size_t n_rows,
131  size_t n_cols,
132  const double* coef,
133  double intercept,
134  double* preds);
/**
 * @brief Fit a GLM using quasi-Newton methods (dense input).
 *
 * @tparam T Floating-point value type of the data and weights.
 * @tparam I Integer type used for sizes; defaults to int.
 *
 * @param cuml_handle   RAFT handle providing the execution resources.
 * @param params        Solver/model configuration (qn_params).
 * @param X             Dense feature matrix, presumably N x D values.
 * @param X_col_major   true if X is stored column-major, false for row-major.
 * @param y             Targets, presumably N entries.
 * @param N             Presumably the number of samples (rows of X).
 * @param D             Presumably the number of features (columns of X).
 * @param C             Presumably the number of classes/outputs — confirm.
 * @param w0            Weight buffer; presumably holds the initial guess on entry
 *                      and the fitted coefficients on return — confirm.
 * @param f             Output; presumably the final objective value — confirm.
 * @param num_iters     Output; presumably the number of iterations performed.
 * @param sample_weight Optional per-sample weights, default nullptr
 *                      (presumably meaning unweighted).
 * @param svr_eps       Presumably the epsilon of an SVR-style loss; default 0.
 */
156 template <typename T, typename I = int>
157 void qnFit(const raft::handle_t& cuml_handle,
158  const qn_params& params,
159  T* X,
160  bool X_col_major,
161  T* y,
162  I N,
163  I D,
164  I C,
165  T* w0,
166  T* f,
167  int* num_iters,
168  T* sample_weight = nullptr,
169  T svr_eps = 0);
170 
/**
 * @brief Fit a GLM using quasi-Newton methods (sparse input).
 *
 * The sparse matrix is given by X_values / X_cols / X_row_ids / X_nnz.
 * NOTE(review): the exact sparse format is not stated in this header. It
 * looks like CSR (X_row_ids as N+1 row offsets), but it could be per-entry
 * row indices — confirm.
 *
 * @tparam T Floating-point value type of the data and weights.
 * @tparam I Integer type used for indices and sizes; defaults to int.
 *
 * @param cuml_handle   RAFT handle providing the execution resources.
 * @param params        Solver/model configuration (qn_params).
 * @param X_values      Nonzero values of X.
 * @param X_cols        Column index of each nonzero.
 * @param X_row_ids     Row ids/offsets describing the rows of X (see note above).
 * @param X_nnz         Number of nonzero entries.
 * @param y             Targets, presumably N entries.
 * @param N             Presumably the number of samples (rows of X).
 * @param D             Presumably the number of features (columns of X).
 * @param C             Presumably the number of classes/outputs — confirm.
 * @param w0            Weight buffer; presumably the initial guess on entry
 *                      and the fitted coefficients on return — confirm.
 * @param f             Output; presumably the final objective value — confirm.
 * @param num_iters     Output; presumably the number of iterations performed.
 * @param sample_weight Optional per-sample weights, default nullptr
 *                      (presumably meaning unweighted).
 * @param svr_eps       Presumably the epsilon of an SVR-style loss; default 0.
 */
193 template <typename T, typename I = int>
194 void qnFitSparse(const raft::handle_t& cuml_handle,
195  const qn_params& params,
196  T* X_values,
197  I* X_cols,
198  I* X_row_ids,
199  I X_nnz,
200  T* y,
201  I N,
202  I D,
203  I C,
204  T* w0,
205  T* f,
206  int* num_iters,
207  T* sample_weight = nullptr,
208  T svr_eps = 0);
209 
/**
 * @brief Obtain the confidence scores of samples (dense input).
 *
 * @tparam T Floating-point value type.
 * @tparam I Integer type used for sizes; defaults to int.
 *
 * @param cuml_handle RAFT handle providing the execution resources.
 * @param params      Model configuration (qn_params) matching the fitted model.
 * @param X           Dense feature matrix, presumably N x D values.
 * @param X_col_major true if X is stored column-major, false for row-major.
 * @param N           Presumably the number of samples.
 * @param D           Presumably the number of features.
 * @param C           Presumably the number of classes/outputs — confirm.
 * @param coefs       Model coefficients (presumably as produced by qnFit).
 * @param scores      Output scores; the shape/layout (N x C vs C x N) is not
 *                    stated here — confirm.
 */
225 template <typename T, typename I = int>
226 void qnDecisionFunction(const raft::handle_t& cuml_handle,
227  const qn_params& params,
228  T* X,
229  bool X_col_major,
230  I N,
231  I D,
232  I C,
233  T* coefs,
234  T* scores);
235 
/**
 * @brief Obtain the confidence scores of samples (sparse input).
 *
 * NOTE(review): the sparse format of X_values / X_cols / X_row_ids is not
 * stated here; it looks like CSR — confirm.
 *
 * @tparam T Floating-point value type.
 * @tparam I Integer type used for indices and sizes; defaults to int.
 *
 * @param cuml_handle RAFT handle providing the execution resources.
 * @param params      Model configuration (qn_params) matching the fitted model.
 * @param X_values    Nonzero values of X.
 * @param X_cols      Column index of each nonzero.
 * @param X_row_ids   Row ids/offsets describing the rows of X.
 * @param X_nnz       Number of nonzero entries.
 * @param N           Presumably the number of samples.
 * @param D           Presumably the number of features.
 * @param C           Presumably the number of classes/outputs — confirm.
 * @param coefs       Model coefficients (presumably as produced by qnFit).
 * @param scores      Output scores; the shape/layout is not stated here — confirm.
 */
254 template <typename T, typename I = int>
255 void qnDecisionFunctionSparse(const raft::handle_t& cuml_handle,
256  const qn_params& params,
257  T* X_values,
258  I* X_cols,
259  I* X_row_ids,
260  I X_nnz,
261  I N,
262  I D,
263  I C,
264  T* coefs,
265  T* scores);
266 
/**
 * @brief Predict using a GLM fitted with quasi-Newton methods (dense input).
 *
 * @tparam T Floating-point value type.
 * @tparam I Integer type used for sizes; defaults to int.
 *
 * @param cuml_handle RAFT handle providing the execution resources.
 * @param params      Model configuration (qn_params) matching the fitted model.
 * @param X           Dense feature matrix, presumably N x D values.
 * @param X_col_major true if X is stored column-major, false for row-major.
 * @param N           Presumably the number of samples.
 * @param D           Presumably the number of features.
 * @param C           Presumably the number of classes/outputs — confirm.
 * @param coefs       Model coefficients (presumably as produced by qnFit).
 * @param preds       Output predictions, presumably one per sample — confirm.
 */
282 template <typename T, typename I = int>
283 void qnPredict(const raft::handle_t& cuml_handle,
284  const qn_params& params,
285  T* X,
286  bool X_col_major,
287  I N,
288  I D,
289  I C,
290  T* coefs,
291  T* preds);
292 
/**
 * @brief Predict using a GLM fitted with quasi-Newton methods (sparse input).
 *
 * NOTE(review): the sparse format of X_values / X_cols / X_row_ids is not
 * stated here; it looks like CSR — confirm.
 *
 * @tparam T Floating-point value type.
 * @tparam I Integer type used for indices and sizes; defaults to int.
 *
 * @param cuml_handle RAFT handle providing the execution resources.
 * @param params      Model configuration (qn_params) matching the fitted model.
 * @param X_values    Nonzero values of X.
 * @param X_cols      Column index of each nonzero.
 * @param X_row_ids   Row ids/offsets describing the rows of X.
 * @param X_nnz       Number of nonzero entries.
 * @param N           Presumably the number of samples.
 * @param D           Presumably the number of features.
 * @param C           Presumably the number of classes/outputs — confirm.
 * @param coefs       Model coefficients (presumably as produced by qnFit).
 * @param preds       Output predictions, presumably one per sample — confirm.
 */
311 template <typename T, typename I = int>
312 void qnPredictSparse(const raft::handle_t& cuml_handle,
313  const qn_params& params,
314  T* X_values,
315  I* X_cols,
316  I* X_row_ids,
317  I X_nnz,
318  I N,
319  I D,
320  I C,
321  T* coefs,
322  T* preds);
323 
324 } // namespace GLM
325 } // namespace ML
Definition: params.hpp:34
void gemmPredict(const raft::handle_t &handle, const float *input, size_t n_rows, size_t n_cols, const float *coef, float intercept, float *preds)
void olsFit(const raft::handle_t &handle, float *input, size_t n_rows, size_t n_cols, float *labels, float *coef, float *intercept, bool fit_intercept, bool normalize, int algo=0, float *sample_weight=nullptr)
void ridgeFit(const raft::handle_t &handle, float *input, size_t n_rows, size_t n_cols, float *labels, float *alpha, int n_alpha, float *coef, float *intercept, bool fit_intercept, bool normalize, int algo=0, float *sample_weight=nullptr)
void qnDecisionFunctionSparse(const raft::handle_t &cuml_handle, const qn_params &params, T *X_values, I *X_cols, I *X_row_ids, I X_nnz, I N, I D, I C, T *coefs, T *scores)
Obtain the confidence scores of samples.
void qnDecisionFunction(const raft::handle_t &cuml_handle, const qn_params &params, T *X, bool X_col_major, I N, I D, I C, T *coefs, T *scores)
Obtain the confidence scores of samples.
void qnFit(const raft::handle_t &cuml_handle, const qn_params &params, T *X, bool X_col_major, T *y, I N, I D, I C, T *w0, T *f, int *num_iters, T *sample_weight=nullptr, T svr_eps=0)
Fit a GLM using quasi-Newton methods.
void qnFitSparse(const raft::handle_t &cuml_handle, const qn_params &params, T *X_values, I *X_cols, I *X_row_ids, I X_nnz, T *y, I N, I D, I C, T *w0, T *f, int *num_iters, T *sample_weight=nullptr, T svr_eps=0)
Fit a GLM using quasi-Newton methods.
void qnPredict(const raft::handle_t &cuml_handle, const qn_params &params, T *X, bool X_col_major, I N, I D, I C, T *coefs, T *preds)
Predict using a GLM fitted with quasi-Newton methods.
void qnPredictSparse(const raft::handle_t &cuml_handle, const qn_params &params, T *X_values, I *X_cols, I *X_row_ids, I X_nnz, I N, I D, I C, T *coefs, T *preds)
Predict using a GLM fitted with quasi-Newton methods.
void normalize(value_t *data, value_idx n, size_t m, cudaStream_t stream)
Definition: utils.h:194
Definition: dbscan.hpp:30
Definition: qn.h:67