solver.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2018-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
/* Forward declaration only: lets the signatures below name raft::handle_t by
 * reference without this header pulling in the handle's full definition. */
8 namespace raft {
9 class handle_t;
10 }
11 
12 namespace ML {
13 namespace Solver {
14 
/**
 * Single-precision (float) overload of sgdFit. Only the prototype lives in
 * this header; the definition is elsewhere.
 *
 * NOTE(review): the meanings below are inferred from parameter names and
 * types alone - confirm against the implementation before relying on them.
 *
 * @param handle            raft handle, taken by non-const reference.
 * @param input             non-const float pointer; presumably the
 *                          n_rows x n_cols training matrix (layout and memory
 *                          space are not visible here - TODO confirm).
 * @param n_rows            int row count of input (presumably samples).
 * @param n_cols            int column count of input (presumably features).
 * @param labels            non-const float pointer; presumably one target per
 *                          row - confirm.
 * @param coef              non-const float pointer; presumably receives the
 *                          fitted coefficients - confirm.
 * @param intercept         non-const float pointer; presumably receives the
 *                          fitted intercept - confirm.
 * @param fit_intercept     bool flag; presumably enables intercept fitting.
 * @param batch_size        int; presumably the mini-batch size.
 * @param epochs            int; presumably the maximum number of passes.
 * @param lr_type           int; presumably a learning-rate schedule code
 *                          (enum not declared in this header - confirm).
 * @param eta0              float; presumably the initial learning rate.
 * @param power_t           float; presumably the schedule exponent.
 * @param loss              int; presumably a loss-function code - confirm.
 * @param penalty           int; presumably a regularization code - confirm.
 * @param alpha             float; presumably the regularization strength.
 * @param l1_ratio          float; presumably the L1/L2 mixing ratio.
 * @param shuffle           bool; presumably shuffles samples between epochs.
 * @param tol               float; presumably the convergence tolerance.
 * @param n_iter_no_change  int; presumably the early-stopping patience.
 */
15 void sgdFit(raft::handle_t& handle,
16  float* input,
17  int n_rows,
18  int n_cols,
19  float* labels,
20  float* coef,
21  float* intercept,
22  bool fit_intercept,
23  int batch_size,
24  int epochs,
25  int lr_type,
26  float eta0,
27  float power_t,
28  int loss,
29  int penalty,
30  float alpha,
31  float l1_ratio,
32  bool shuffle,
33  float tol,
34  int n_iter_no_change);
35 
/**
 * Double-precision overload of sgdFit. The parameter list matches the float
 * overload declared above position for position, with every float / float*
 * replaced by double / double*; the int and bool parameters are unchanged.
 *
 * NOTE(review): parameter semantics are not visible in this header; they are
 * presumably identical to the float overload - confirm against the
 * implementation. input, labels, coef and intercept are all non-const
 * pointers, so the callee is allowed to write through any of them.
 */
36 void sgdFit(raft::handle_t& handle,
37  double* input,
38  int n_rows,
39  int n_cols,
40  double* labels,
41  double* coef,
42  double* intercept,
43  bool fit_intercept,
44  int batch_size,
45  int epochs,
46  int lr_type,
47  double eta0,
48  double power_t,
49  int loss,
50  int penalty,
51  double alpha,
52  double l1_ratio,
53  bool shuffle,
54  double tol,
55  int n_iter_no_change);
56 
/**
 * Single-precision (float) overload of sgdPredict. Prototype only; the
 * definition is elsewhere.
 *
 * NOTE(review): meanings are inferred from names and types - confirm.
 *
 * @param handle     raft handle, taken by non-const reference.
 * @param input      read-only (const) float pointer; presumably an
 *                   n_rows x n_cols matrix to predict on - TODO confirm layout.
 * @param n_rows     int row count of input.
 * @param n_cols     int column count of input.
 * @param coef       read-only (const) float pointer; presumably the fitted
 *                   coefficients.
 * @param intercept  float passed by value (unlike sgdFit, where it is a
 *                   pointer).
 * @param preds      non-const float pointer; presumably receives one
 *                   prediction per row - confirm.
 * @param loss       int; presumably a loss-function code - confirm.
 */
57 void sgdPredict(raft::handle_t& handle,
58  const float* input,
59  int n_rows,
60  int n_cols,
61  const float* coef,
62  float intercept,
63  float* preds,
64  int loss);
65 
/**
 * Double-precision overload of sgdPredict. Same parameter order as the float
 * overload declared above, with float replaced by double.
 *
 * input and coef are const pointers (read-only to the callee), intercept is
 * passed by value, and preds is the only non-const data pointer.
 * NOTE(review): preds is presumably the output buffer with one value per
 * row - confirm against the implementation.
 */
66 void sgdPredict(raft::handle_t& handle,
67  const double* input,
68  int n_rows,
69  int n_cols,
70  const double* coef,
71  double intercept,
72  double* preds,
73  int loss);
74 
/**
 * Single-precision (float) overload of sgdPredictBinaryClass. Prototype only;
 * the definition is elsewhere. The parameter list is identical to the float
 * sgdPredict declared above.
 *
 * NOTE(review): how the output differs from sgdPredict is not visible in this
 * header; presumably preds receives binary class labels rather than raw
 * scores - confirm against the implementation.
 *
 * @param handle     raft handle, taken by non-const reference.
 * @param input      read-only (const) float pointer; presumably an
 *                   n_rows x n_cols matrix - TODO confirm layout.
 * @param n_rows     int row count of input.
 * @param n_cols     int column count of input.
 * @param coef       read-only (const) float pointer.
 * @param intercept  float passed by value.
 * @param preds      non-const float pointer; presumably the output buffer.
 * @param loss       int; presumably a loss-function code - confirm.
 */
75 void sgdPredictBinaryClass(raft::handle_t& handle,
76  const float* input,
77  int n_rows,
78  int n_cols,
79  const float* coef,
80  float intercept,
81  float* preds,
82  int loss);
83 
/**
 * Double-precision overload of sgdPredictBinaryClass. Same parameter order as
 * the float overload declared above, with float replaced by double.
 *
 * input and coef are const pointers, intercept is passed by value, and preds
 * is the only non-const data pointer.
 * NOTE(review): preds is presumably the output buffer; the exact encoding of
 * the predicted classes is not visible here - confirm.
 */
84 void sgdPredictBinaryClass(raft::handle_t& handle,
85  const double* input,
86  int n_rows,
87  int n_cols,
88  const double* coef,
89  double intercept,
90  double* preds,
91  int loss);
92 
/**
 * Single-precision (float) overload of cdFit. Prototype only; the definition
 * is elsewhere. Unlike the sgd* functions in this header, it returns an int.
 *
 * NOTE(review): this listing's numbering jumps from 92 to 140 immediately
 * before this declaration; presumably the upstream header carries a
 * documentation comment there that this rendering omits - consult it. The
 * meanings below are inferred from names and types only.
 *
 * @param handle         raft handle, taken by non-const reference.
 * @param input          non-const float pointer; presumably the
 *                       n_rows x n_cols training matrix - TODO confirm layout.
 * @param n_rows         int row count of input.
 * @param n_cols         int column count of input.
 * @param labels         non-const float pointer; presumably one target per row.
 * @param coef           non-const float pointer; presumably receives the
 *                       fitted coefficients.
 * @param intercept      non-const float pointer; presumably receives the
 *                       fitted intercept.
 * @param fit_intercept  bool flag; presumably enables intercept fitting.
 * @param epochs         int; presumably the maximum number of passes.
 * @param loss           int; presumably a loss-function code - confirm.
 * @param alpha          float; presumably the regularization strength.
 * @param l1_ratio       float; presumably the L1/L2 mixing ratio.
 * @param shuffle        bool; presumably randomizes coordinate order.
 * @param tol            float; presumably the convergence tolerance.
 * @param sample_weight  optional non-const float pointer, defaults to nullptr;
 *                       presumably nullptr means unweighted samples - confirm.
 * @return int whose meaning is not visible in this header; presumably the
 *         number of iterations performed - confirm.
 */
140 int cdFit(raft::handle_t& handle,
141  float* input,
142  int n_rows,
143  int n_cols,
144  float* labels,
145  float* coef,
146  float* intercept,
147  bool fit_intercept,
148  int epochs,
149  int loss,
150  float alpha,
151  float l1_ratio,
152  bool shuffle,
153  float tol,
154  float* sample_weight = nullptr);
155 
/**
 * Double-precision overload of cdFit. Same parameter order as the float
 * overload declared above, with float / float* replaced by double / double*.
 *
 * sample_weight is the only defaulted parameter (nullptr), so callers may
 * omit it. NOTE(review): the meaning of the int return value and of a null
 * sample_weight are not visible in this header; presumably they match the
 * float overload - confirm against the implementation.
 */
156 int cdFit(raft::handle_t& handle,
157  double* input,
158  int n_rows,
159  int n_cols,
160  double* labels,
161  double* coef,
162  double* intercept,
163  bool fit_intercept,
164  int epochs,
165  int loss,
166  double alpha,
167  double l1_ratio,
168  bool shuffle,
169  double tol,
170  double* sample_weight = nullptr);
171 
/**
 * Single-precision (float) overload of cdPredict. Prototype only; the
 * definition is elsewhere. The parameter list has the same shape as the float
 * sgdPredict declared earlier in this header.
 *
 * NOTE(review): meanings are inferred from names and types - confirm.
 *
 * @param handle     raft handle, taken by non-const reference.
 * @param input      read-only (const) float pointer; presumably an
 *                   n_rows x n_cols matrix to predict on - TODO confirm layout.
 * @param n_rows     int row count of input.
 * @param n_cols     int column count of input.
 * @param coef       read-only (const) float pointer; presumably the fitted
 *                   coefficients.
 * @param intercept  float passed by value (cdFit takes it as a pointer).
 * @param preds      non-const float pointer; presumably receives one
 *                   prediction per row - confirm.
 * @param loss       int; presumably a loss-function code - confirm.
 */
172 void cdPredict(raft::handle_t& handle,
173  const float* input,
174  int n_rows,
175  int n_cols,
176  const float* coef,
177  float intercept,
178  float* preds,
179  int loss);
180 
/**
 * Double-precision overload of cdPredict. Same parameter order as the float
 * overload declared above, with float replaced by double.
 *
 * input and coef are const pointers, intercept is passed by value, and preds
 * is the only non-const data pointer.
 * NOTE(review): preds is presumably the output buffer with one value per
 * row - confirm against the implementation.
 */
181 void cdPredict(raft::handle_t& handle,
182  const double* input,
183  int n_rows,
184  int n_cols,
185  const double* coef,
186  double intercept,
187  double* preds,
188  int loss);
189 
190 }; // namespace Solver
191 }; // end namespace ML
void sgdPredictBinaryClass(raft::handle_t &handle, const float *input, int n_rows, int n_cols, const float *coef, float intercept, float *preds, int loss)
void cdPredict(raft::handle_t &handle, const float *input, int n_rows, int n_cols, const float *coef, float intercept, float *preds, int loss)
void sgdFit(raft::handle_t &handle, float *input, int n_rows, int n_cols, float *labels, float *coef, float *intercept, bool fit_intercept, int batch_size, int epochs, int lr_type, float eta0, float power_t, int loss, int penalty, float alpha, float l1_ratio, bool shuffle, float tol, int n_iter_no_change)
void shuffle(std::vector< math_t > &rand_indices, std::mt19937 &g)
Definition: shuffle.h:24
void sgdPredict(raft::handle_t &handle, const float *input, int n_rows, int n_cols, const float *coef, float intercept, float *preds, int loss)
int cdFit(raft::handle_t &handle, float *input, int n_rows, int n_cols, float *labels, float *coef, float *intercept, bool fit_intercept, int epochs, int loss, float alpha, float l1_ratio, bool shuffle, float tol, float *sample_weight=nullptr)
Definition: dbscan.hpp:18
lr_type
Definition: params.hpp:10
penalty
Definition: params.hpp:23
Definition: dbscan.hpp:14