solver.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2018-2022, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
// Forward declaration only: every solver entry point declared in this header
// takes raft::handle_t by reference, so the complete type is not required here.
namespace raft { class handle_t; }  // namespace raft
22 
23 namespace ML {
24 namespace Solver {
25 
26 void sgdFit(raft::handle_t& handle,
27  float* input,
28  int n_rows,
29  int n_cols,
30  float* labels,
31  float* coef,
32  float* intercept,
33  bool fit_intercept,
34  int batch_size,
35  int epochs,
36  int lr_type,
37  float eta0,
38  float power_t,
39  int loss,
40  int penalty,
41  float alpha,
42  float l1_ratio,
43  bool shuffle,
44  float tol,
45  int n_iter_no_change);
46 
47 void sgdFit(raft::handle_t& handle,
48  double* input,
49  int n_rows,
50  int n_cols,
51  double* labels,
52  double* coef,
53  double* intercept,
54  bool fit_intercept,
55  int batch_size,
56  int epochs,
57  int lr_type,
58  double eta0,
59  double power_t,
60  int loss,
61  int penalty,
62  double alpha,
63  double l1_ratio,
64  bool shuffle,
65  double tol,
66  int n_iter_no_change);
67 
68 void sgdPredict(raft::handle_t& handle,
69  const float* input,
70  int n_rows,
71  int n_cols,
72  const float* coef,
73  float intercept,
74  float* preds,
75  int loss);
76 
77 void sgdPredict(raft::handle_t& handle,
78  const double* input,
79  int n_rows,
80  int n_cols,
81  const double* coef,
82  double intercept,
83  double* preds,
84  int loss);
85 
86 void sgdPredictBinaryClass(raft::handle_t& handle,
87  const float* input,
88  int n_rows,
89  int n_cols,
90  const float* coef,
91  float intercept,
92  float* preds,
93  int loss);
94 
95 void sgdPredictBinaryClass(raft::handle_t& handle,
96  const double* input,
97  int n_rows,
98  int n_cols,
99  const double* coef,
100  double intercept,
101  double* preds,
102  int loss);
103 
152 void cdFit(raft::handle_t& handle,
153  float* input,
154  int n_rows,
155  int n_cols,
156  float* labels,
157  float* coef,
158  float* intercept,
159  bool fit_intercept,
160  bool normalize,
161  int epochs,
162  int loss,
163  float alpha,
164  float l1_ratio,
165  bool shuffle,
166  float tol,
167  float* sample_weight = nullptr);
168 
169 void cdFit(raft::handle_t& handle,
170  double* input,
171  int n_rows,
172  int n_cols,
173  double* labels,
174  double* coef,
175  double* intercept,
176  bool fit_intercept,
177  bool normalize,
178  int epochs,
179  int loss,
180  double alpha,
181  double l1_ratio,
182  bool shuffle,
183  double tol,
184  double* sample_weight = nullptr);
185 
186 void cdPredict(raft::handle_t& handle,
187  const float* input,
188  int n_rows,
189  int n_cols,
190  const float* coef,
191  float intercept,
192  float* preds,
193  int loss);
194 
195 void cdPredict(raft::handle_t& handle,
196  const double* input,
197  int n_rows,
198  int n_cols,
199  const double* coef,
200  double intercept,
201  double* preds,
202  int loss);
203 
204 }; // namespace Solver
205 }; // end namespace ML
void normalize(value_t *data, value_idx n, size_t m, cudaStream_t stream)
Definition: utils.h:194
void sgdPredictBinaryClass(raft::handle_t &handle, const float *input, int n_rows, int n_cols, const float *coef, float intercept, float *preds, int loss)
void cdPredict(raft::handle_t &handle, const float *input, int n_rows, int n_cols, const float *coef, float intercept, float *preds, int loss)
void sgdFit(raft::handle_t &handle, float *input, int n_rows, int n_cols, float *labels, float *coef, float *intercept, bool fit_intercept, int batch_size, int epochs, int lr_type, float eta0, float power_t, int loss, int penalty, float alpha, float l1_ratio, bool shuffle, float tol, int n_iter_no_change)
void cdFit(raft::handle_t &handle, float *input, int n_rows, int n_cols, float *labels, float *coef, float *intercept, bool fit_intercept, bool normalize, int epochs, int loss, float alpha, float l1_ratio, bool shuffle, float tol, float *sample_weight=nullptr)
void shuffle(std::vector< math_t > &rand_indices, std::mt19937 &g)
Definition: shuffle.h:35
void sgdPredict(raft::handle_t &handle, const float *input, int n_rows, int n_cols, const float *coef, float intercept, float *preds, int loss)
Definition: dbscan.hpp:30
lr_type
Definition: params.hpp:21
penalty
Definition: params.hpp:34
Definition: dbscan.hpp:26