linear.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
8 #include <cuml/common/logger.hpp>
9 
10 #include <raft/core/handle.hpp>
11 
12 namespace ML {
13 namespace SVM {
14 namespace linear {
15 
16 struct Params {
18  enum Penalty {
20  L1,
22  L2
23  };
25  enum Loss {
34  };
35 
41  bool fit_intercept = true;
45  bool penalized_intercept = false;
47  int max_iter = 1000;
55  int lbfgs_memory = 5;
57  rapids_logger::level_enum verbose = rapids_logger::level_enum::off;
62  double C = 1.0;
64  double grad_tol = 0.0001;
66  double change_tol = 0.00001;
68  double epsilon = 0.0;
69 };
70 
92 template <typename T>
93 int fit(const raft::handle_t& handle,
94  const Params& params,
95  const std::size_t nRows,
96  const std::size_t nCols,
97  const int nClasses,
98  const T* classes,
99  const T* X,
100  const T* y,
101  const T* sampleWeight,
102  T* w,
103  T* probScale);
104 
116 template <typename T>
117 void computeProbabilities(const raft::handle_t& handle,
118  const std::size_t nRows,
119  const int nClasses,
120  const T* probScale,
121  T* scores,
122  T* out);
123 
124 } // namespace linear
125 } // namespace SVM
126 } // namespace ML
Definition: params.hpp:23
void computeProbabilities(const raft::handle_t &handle, const std::size_t nRows, const int nClasses, const T *probScale, T *scores, T *out)
Compute probabilities from decision function scores.
int fit(const raft::handle_t &handle, const Params &params, const std::size_t nRows, const std::size_t nCols, const int nClasses, const T *classes, const T *X, const T *y, const T *sampleWeight, T *w, T *probScale)
Fit a linear SVM model.
Definition: dbscan.hpp:18
penalty
Definition: params.hpp:23
Definition: linear.hpp:16
double grad_tol
Definition: linear.hpp:64
bool penalized_intercept
Definition: linear.hpp:45
Loss
Definition: linear.hpp:25
@ SQUARED_EPSILON_INSENSITIVE
Definition: linear.hpp:33
@ SQUARED_HINGE
Definition: linear.hpp:29
@ HINGE
Definition: linear.hpp:27
@ EPSILON_INSENSITIVE
Definition: linear.hpp:31
int lbfgs_memory
Definition: linear.hpp:55
int max_iter
Definition: linear.hpp:47
rapids_logger::level_enum verbose
Definition: linear.hpp:57
Loss loss
Definition: linear.hpp:39
bool fit_intercept
Definition: linear.hpp:41
double change_tol
Definition: linear.hpp:66
Penalty
Definition: linear.hpp:18
@ L2
Definition: linear.hpp:22
@ L1
Definition: linear.hpp:20
double C
Definition: linear.hpp:62
double epsilon
Definition: linear.hpp:68
int linesearch_max_iter
Definition: linear.hpp:51