smosolver.h
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
#include <cuml/common/logger.hpp>
#include <cuml/svm/svm_model.h>

#include <raft/core/handle.hpp>

#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>

#include <thrust/device_ptr.h>

#include <cuvs/distance/distance.hpp>
#include <cuvs/distance/grammian.hpp>

#include <algorithm>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <type_traits>
29 
30 namespace ML {
31 namespace SVM {
32 
58 template <typename math_t>
59 class SmoSolver {
60  public:
61  SmoSolver(const raft::handle_t& handle,
62  SvmParameter param,
64  cuvs::distance::kernels::GramMatrixBase<math_t>* kernel,
65  bool is_precomputed = false)
66  : handle(handle),
67  C(param.C),
68  tol(param.tol),
69  kernel(kernel),
70  kernel_type(kernel_type),
71  is_precomputed(is_precomputed),
72  cache_size(param.cache_size),
73  nochange_steps(param.nochange_steps),
74  epsilon(param.epsilon),
75  svmType(param.svmType),
76  stream(handle.get_stream()),
77  return_buff(2, stream),
78  alpha(0, stream),
79  C_vec(0, stream),
80  delta_alpha(0, stream),
81  f(0, stream),
82  y_label(0, stream)
83  {
84  ML::default_logger().set_level(param.verbosity);
85  }
86 
87  void GetNonzeroDeltaAlpha(const math_t* vec,
88  int n_ws,
89  const int* idx,
90  math_t* nz_vec,
91  int* n_nz,
92  int* nz_idx,
93  cudaStream_t stream);
116  template <typename MatrixViewType>
117  void Solve(MatrixViewType matrix,
118  int n_rows,
119  int n_cols,
120  math_t* y,
121  const math_t* sample_weight,
122  math_t** dual_coefs,
123  int* n_support,
124  SupportStorage<math_t>* support_matrix,
125  int** idx,
126  math_t* b,
127  int max_iter = -1,
128  int max_outer_iter = -1,
129  int max_inner_iter = 10000);
130 
146  void UpdateF(math_t* f, int n_rows, const math_t* delta_alpha, int n_ws, const math_t* cacheTile);
147 
169  void Initialize(math_t** y, const math_t* sample_weight, int n_rows, int n_cols);
170 
171  void InitPenalty(math_t* C_vec, const math_t* sample_weight, int n_rows);
172 
185  void SvcInit(const math_t* y);
186 
220  void SvrInit(const math_t* yr, int n_rows, math_t* yc, math_t* f);
221 
222  int GetNIter() { return n_iter; };
223 
224  private:
225  const raft::handle_t& handle;
226  cudaStream_t stream;
227 
228  int n_rows = 0;
229  int n_cols = 0;
230  int n_ws = 0;
231  int n_train = 0;
232 
233  // Buffers for the domain [n_train]
234  rmm::device_uvector<math_t> alpha;
235  rmm::device_uvector<math_t> f;
236  rmm::device_uvector<math_t> y_label;
237 
238  rmm::device_uvector<math_t> C_vec;
239 
240  // Buffers for the working set [n_ws]
242  rmm::device_uvector<math_t> delta_alpha;
243 
244  // Buffers to return some parameters from the kernel (iteration number, and
245  // convergence information)
246  rmm::device_uvector<math_t> return_buff;
247  math_t host_return_buff[2];
248 
249  math_t C;
250  math_t tol;
251  math_t epsilon;
252 
253  cuvs::distance::kernels::GramMatrixBase<math_t>* kernel;
255  bool is_precomputed;
256  float cache_size;
257 
258  SvmType svmType;
259 
260  // Variables to track convergence of training
261  math_t diff_prev;
262  int n_small_diff;
263  int nochange_steps;
264  int n_increased_diff;
265  int n_outer_iter;
266  int n_iter;
267  bool report_increased_diff;
268 
269  bool CheckStoppingCondition(math_t diff)
270  {
271  if (diff > diff_prev * 1.5 && n_outer_iter > 0) {
272  // Ideally, diff should decrease monotonically. In practice we can have
273  // small fluctuations (10% increase is not uncommon). Here we consider a
274  // 50% increase in the diff value large enough to indicate a problem.
275  // The 50% value is an educated guess that triggers the convergence debug
276  // message for problematic use cases while avoids false alarms in many
277  // other cases.
278  n_increased_diff++;
279  }
280  if (report_increased_diff && n_outer_iter > 100 && n_increased_diff > n_outer_iter * 0.1) {
281  CUML_LOG_DEBUG(
282  "Solver is not converging monotonically. This might be caused by "
283  "insufficient normalization of the feature columns. In that case "
284  "MinMaxScaler((0,1)) could help. Alternatively, for nonlinear kernels, "
285  "you can try to increase the gamma parameter. To limit execution time, "
286  "you can also adjust the number of iterations using the max_iter "
287  "parameter.");
288  report_increased_diff = false;
289  }
290  bool keep_going = true;
291  if (abs(diff - diff_prev) < 0.001 * tol) {
292  n_small_diff++;
293  } else {
294  diff_prev = diff;
295  n_small_diff = 0;
296  }
297  if (n_small_diff > nochange_steps) {
298  CUML_LOG_ERROR(
299  "SMO error: Stopping due to unchanged diff over %d"
300  " consecutive steps",
301  nochange_steps);
302  keep_going = false;
303  }
304  if (diff < tol) keep_going = false;
305  if (isnan(diff)) {
306  std::string txt;
307  if (std::is_same<float, math_t>::value) {
308  txt +=
309  " This might be caused by floating point overflow. In such case using"
310  " fp64 could help. Alternatively, try gamma='scale' kernel"
311  " parameter.";
312  }
313  THROW("SMO error: NaN found during fitting.%s", txt.c_str());
314  }
315  return keep_going;
316  }
317 
319  int GetDefaultMaxIter(int n_train, int max_outer_iter)
320  {
321  if (max_outer_iter == -1) {
322  max_outer_iter = n_train < std::numeric_limits<int>::max() / 100
323  ? n_train * 100
325  max_outer_iter = max(100000, max_outer_iter);
326  }
327  // else we have user defined iteration count which we do not change
328  return max_outer_iter;
329  }
330 
331  void ResizeBuffers(int n_train, int n_cols)
332  {
333  // This needs to know n_train, therefore it can be only called during solve
334  alpha.resize(n_train, stream);
335  C_vec.resize(n_train, stream);
336  f.resize(n_train, stream);
337  delta_alpha.resize(n_ws, stream);
338  if (svmType == EPSILON_SVR) y_label.resize(n_train, stream);
339  }
340 
341  void ReleaseBuffers()
342  {
343  alpha.resize(0, stream);
344  delta_alpha.resize(0, stream);
345  f.resize(0, stream);
346  y_label.resize(0, stream);
347  }
348 };
349 
350 }; // end namespace SVM
351 }; // end namespace ML
Solve the quadratic optimization problem using two level decomposition and Sequential Minimal Optimization.
Definition: smosolver.h:59
int GetNIter()
Definition: smosolver.h:222
void SvrInit(const math_t *yr, int n_rows, math_t *yc, math_t *f)
Initializes the solver for epsilon-SVR.
void UpdateF(math_t *f, int n_rows, const math_t *delta_alpha, int n_ws, const math_t *cacheTile)
Update the f vector after a block solve step.
void Initialize(math_t **y, const math_t *sample_weight, int n_rows, int n_cols)
Initialize the problem to solve.
void SvcInit(const math_t *y)
Initialize Support Vector Classification.
void GetNonzeroDeltaAlpha(const math_t *vec, int n_ws, const int *idx, math_t *nz_vec, int *n_nz, int *nz_idx, cudaStream_t stream)
SmoSolver(const raft::handle_t &handle, SvmParameter param, cuvs::distance::kernels::KernelType kernel_type, cuvs::distance::kernels::GramMatrixBase< math_t > *kernel, bool is_precomputed=false)
Definition: smosolver.h:61
void InitPenalty(math_t *C_vec, const math_t *sample_weight, int n_rows)
void Solve(MatrixViewType matrix, int n_rows, int n_cols, math_t *y, const math_t *sample_weight, math_t **dual_coefs, int *n_support, SupportStorage< math_t > *support_matrix, int **idx, math_t *b, int max_iter=-1, int max_outer_iter=-1, int max_inner_iter=10000)
Solve the quadratic optimization problem.
SvmType
Definition: svm_parameter.h:12
@ EPSILON_SVR
Definition: svm_parameter.h:12
math_t max(math_t a, math_t b)
Definition: learning_rate.h:16
KernelType
Definition: kernel_params.hpp:16
Definition: dbscan.hpp:18
rapids_logger::logger & default_logger()
Get the default logger.
Definition: logger.hpp:43
Definition: svm_model.h:12
Definition: svm_parameter.h:27
rapids_logger::level_enum verbosity
Print information about training.
Definition: svm_parameter.h:34