solver.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2018-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
namespace raft {
class handle_t;
}  // namespace raft

namespace ML {
namespace Solver {

/**
 * @brief Fit entry point of the SGD solver, single-precision overload.
 *
 * Only the declaration is visible in this header. Statements about pointer
 * constness, parameter types and defaults come from the signature itself;
 * every other parameter description below is inferred from the parameter
 * name and should be confirmed against the implementation.
 *
 * @param handle           raft handle (raft::handle_t is only forward-declared here).
 * @param input            feature matrix with n_rows x n_cols entries; taken non-const.
 *                         NOTE(review): layout (row- vs column-major) and memory space
 *                         (host vs device) are not visible here — confirm.
 * @param n_rows           number of samples.
 * @param n_cols           number of features.
 * @param labels           target values, presumably one per row; taken non-const.
 * @param coef             coefficient buffer, presumably n_cols values written by the fit.
 * @param intercept        pointer that presumably receives the fitted intercept.
 * @param fit_intercept    whether to fit an intercept term.
 * @param batch_size       mini-batch size.
 * @param epochs           number of epochs.
 * @param lr_type          learning-rate schedule code. NOTE(review): presumably a value of
 *                         the `lr_type` type declared in params.hpp — confirm.
 * @param eta0             initial learning rate.
 * @param power_t          exponent for a power-based learning-rate schedule.
 * @param loss             loss-function code; the int-to-loss mapping is not visible here.
 * @param penalty          regularization code. NOTE(review): presumably a value of the
 *                         `penalty` type declared in params.hpp — confirm.
 * @param alpha            regularization strength.
 * @param l1_ratio         L1 share of an elastic-net style penalty.
 * @param shuffle          whether to shuffle the data between epochs.
 * @param tol              stopping tolerance.
 * @param n_iter_no_change number of non-improving iterations tolerated before stopping.
 */
void sgdFit(raft::handle_t& handle,
            float* input,
            int n_rows,
            int n_cols,
            float* labels,
            float* coef,
            float* intercept,
            bool fit_intercept,
            int batch_size,
            int epochs,
            int lr_type,
            float eta0,
            float power_t,
            int loss,
            int penalty,
            float alpha,
            float l1_ratio,
            bool shuffle,
            float tol,
            int n_iter_no_change);

/**
 * @brief Double-precision overload of sgdFit(); the parameters mirror the
 *        single-precision overload above, with float replaced by double.
 */
void sgdFit(raft::handle_t& handle,
            double* input,
            int n_rows,
            int n_cols,
            double* labels,
            double* coef,
            double* intercept,
            bool fit_intercept,
            int batch_size,
            int epochs,
            int lr_type,
            double eta0,
            double power_t,
            int loss,
            int penalty,
            double alpha,
            double l1_ratio,
            bool shuffle,
            double tol,
            int n_iter_no_change);

/**
 * @brief Prediction entry point of the SGD solver, single-precision overload.
 *
 * input and coef are const, so they are read-only. preds is the only
 * non-const pointer and presumably receives the predictions (one per row) —
 * confirm against the implementation.
 *
 * @param handle    raft handle.
 * @param input     feature matrix with n_rows x n_cols entries (read-only).
 * @param n_rows    number of samples.
 * @param n_cols    number of features.
 * @param coef      fitted coefficients (read-only), presumably n_cols values.
 * @param intercept intercept value, passed by value.
 * @param preds     output buffer for the predictions.
 * @param loss      loss-function code; presumably the same code used when fitting.
 */
void sgdPredict(raft::handle_t& handle,
                const float* input,
                int n_rows,
                int n_cols,
                const float* coef,
                float intercept,
                float* preds,
                int loss);

/**
 * @brief Double-precision overload of sgdPredict(); the parameters mirror the
 *        single-precision overload above, with float replaced by double.
 */
void sgdPredict(raft::handle_t& handle,
                const double* input,
                int n_rows,
                int n_cols,
                const double* coef,
                double intercept,
                double* preds,
                int loss);

/**
 * @brief Binary-classification prediction entry point of the SGD solver,
 *        single-precision overload.
 *
 * The signature is identical to sgdPredict(). NOTE(review): how the output
 * differs from sgdPredict() (e.g. thresholding to class labels) is not
 * visible here — confirm against the implementation.
 *
 * @param handle    raft handle.
 * @param input     feature matrix with n_rows x n_cols entries (read-only).
 * @param n_rows    number of samples.
 * @param n_cols    number of features.
 * @param coef      fitted coefficients (read-only), presumably n_cols values.
 * @param intercept intercept value, passed by value.
 * @param preds     output buffer for the predicted classes.
 * @param loss      loss-function code; presumably the same code used when fitting.
 */
void sgdPredictBinaryClass(raft::handle_t& handle,
                           const float* input,
                           int n_rows,
                           int n_cols,
                           const float* coef,
                           float intercept,
                           float* preds,
                           int loss);

/**
 * @brief Double-precision overload of sgdPredictBinaryClass(); the parameters
 *        mirror the single-precision overload above, with float replaced by double.
 */
void sgdPredictBinaryClass(raft::handle_t& handle,
                           const double* input,
                           int n_rows,
                           int n_cols,
                           const double* coef,
                           double intercept,
                           double* preds,
                           int loss);

/**
 * @brief Fit entry point of the CD solver (presumably coordinate descent),
 *        single-precision overload.
 *
 * Only the declaration is visible in this header. Descriptions not backed by
 * the signature are inferred from parameter names and should be confirmed.
 *
 * @param handle        raft handle.
 * @param input         feature matrix with n_rows x n_cols entries; taken non-const.
 *                      NOTE(review): layout and memory space are not visible here;
 *                      the `normalize` flag may imply in-place modification — confirm.
 * @param n_rows        number of samples.
 * @param n_cols        number of features.
 * @param labels        target values, presumably one per row; taken non-const.
 * @param coef          coefficient buffer, presumably n_cols values written by the fit.
 * @param intercept     pointer that presumably receives the fitted intercept.
 * @param fit_intercept whether to fit an intercept term.
 * @param normalize     whether to normalize the input features.
 * @param epochs        number of epochs.
 * @param loss          loss-function code; the int-to-loss mapping is not visible here.
 * @param alpha         regularization strength.
 * @param l1_ratio      L1 share of an elastic-net style penalty.
 * @param shuffle       whether to visit coordinates/samples in shuffled order.
 * @param tol           stopping tolerance.
 * @param sample_weight optional per-sample weights; defaults to nullptr, which
 *                      presumably means unweighted — confirm.
 * @return an int whose meaning is not visible here. NOTE(review): confirm
 *         (e.g. iteration count or status code).
 */
int cdFit(raft::handle_t& handle,
          float* input,
          int n_rows,
          int n_cols,
          float* labels,
          float* coef,
          float* intercept,
          bool fit_intercept,
          bool normalize,
          int epochs,
          int loss,
          float alpha,
          float l1_ratio,
          bool shuffle,
          float tol,
          float* sample_weight = nullptr);

/**
 * @brief Double-precision overload of cdFit(); the parameters, the nullptr
 *        default for sample_weight and the int return mirror the
 *        single-precision overload above, with float replaced by double.
 */
int cdFit(raft::handle_t& handle,
          double* input,
          int n_rows,
          int n_cols,
          double* labels,
          double* coef,
          double* intercept,
          bool fit_intercept,
          bool normalize,
          int epochs,
          int loss,
          double alpha,
          double l1_ratio,
          bool shuffle,
          double tol,
          double* sample_weight = nullptr);

/**
 * @brief Prediction entry point of the CD solver, single-precision overload.
 *
 * input and coef are const, so they are read-only. preds is the only
 * non-const pointer and presumably receives the predictions — confirm.
 *
 * @param handle    raft handle.
 * @param input     feature matrix with n_rows x n_cols entries (read-only).
 * @param n_rows    number of samples.
 * @param n_cols    number of features.
 * @param coef      fitted coefficients (read-only), presumably n_cols values.
 * @param intercept intercept value, passed by value.
 * @param preds     output buffer for the predictions.
 * @param loss      loss-function code; presumably the same code used when fitting.
 */
void cdPredict(raft::handle_t& handle,
               const float* input,
               int n_rows,
               int n_cols,
               const float* coef,
               float intercept,
               float* preds,
               int loss);

/**
 * @brief Double-precision overload of cdPredict(); the parameters mirror the
 *        single-precision overload above, with float replaced by double.
 */
void cdPredict(raft::handle_t& handle,
               const double* input,
               int n_rows,
               int n_cols,
               const double* coef,
               double intercept,
               double* preds,
               int loss);

}  // namespace Solver
}  // namespace ML
void normalize(value_t *data, value_idx n, size_t m, cudaStream_t stream)
Definition: utils.h:177
void sgdPredictBinaryClass(raft::handle_t &handle, const float *input, int n_rows, int n_cols, const float *coef, float intercept, float *preds, int loss)
void cdPredict(raft::handle_t &handle, const float *input, int n_rows, int n_cols, const float *coef, float intercept, float *preds, int loss)
void sgdFit(raft::handle_t &handle, float *input, int n_rows, int n_cols, float *labels, float *coef, float *intercept, bool fit_intercept, int batch_size, int epochs, int lr_type, float eta0, float power_t, int loss, int penalty, float alpha, float l1_ratio, bool shuffle, float tol, int n_iter_no_change)
int cdFit(raft::handle_t &handle, float *input, int n_rows, int n_cols, float *labels, float *coef, float *intercept, bool fit_intercept, bool normalize, int epochs, int loss, float alpha, float l1_ratio, bool shuffle, float tol, float *sample_weight=nullptr)
void shuffle(std::vector< math_t > &rand_indices, std::mt19937 &g)
Definition: shuffle.h:24
void sgdPredict(raft::handle_t &handle, const float *input, int n_rows, int n_cols, const float *coef, float intercept, float *preds, int loss)
Definition: dbscan.hpp:18
lr_type
Definition: params.hpp:10
penalty
Definition: params.hpp:23
Definition: dbscan.hpp:14