svr.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
9 
10 #include <cublas_v2.h>
11 
12 namespace ML {
13 namespace SVM {
14 
// Forward declarations: the functions declared in this header only take
// SvmModel<math_t> and SvmParameter by reference, so incomplete types suffice.
15 template <typename math_t>
16 struct SvmModel;
17 struct SvmParameter;
18 
19 // Forward declarations of the stateless API
/**
 * @brief Fit a support vector regressor to the training data.
 *
 * Only the declaration appears here. Prediction has no SVR-specific entry
 * point in this header; per the note at the end of the namespace, svcPredict
 * is used instead.
 *
 * @tparam math_t Element type of X, y and sample_weight; also parameterizes
 *   the SvmModel that is passed in.
 * @param handle RAFT handle, taken by const reference.
 * @param X Pointer to the training data.
 *   NOTE(review): presumably a dense n_rows x n_cols matrix; memory space and
 *   layout are not visible from this declaration -- confirm.
 * @param n_rows Row count. NOTE(review): presumably the number of training
 *   samples -- confirm.
 * @param n_cols Column count. NOTE(review): presumably the number of
 *   features -- confirm.
 * @param y Pointer to math_t values. NOTE(review): presumably the regression
 *   targets, one per row -- confirm.
 * @param param SVM parameters, read-only (const reference).
 * @param kernel_params Kernel parameters, taken by non-const reference.
 * @param model Model object, taken by non-const reference.
 *   NOTE(review): presumably filled in by the fit -- confirm.
 * @param sample_weight Optional pointer to per-sample weights; defaults to
 *   nullptr. NOTE(review): nullptr presumably means unweighted -- confirm.
 * @return int. NOTE(review): the meaning of the return value is not visible
 *   from this declaration -- confirm against the definition.
 */
40 template <typename math_t>
41 int svrFit(const raft::handle_t& handle,
42  math_t* X,
43  int n_rows,
44  int n_cols,
45  math_t* y,
46  const SvmParameter& param,
47  ML::matrix::KernelParams& kernel_params,
48  SvmModel<math_t>& model,
49  const math_t* sample_weight = nullptr);
50 
/**
 * @brief Fit a support vector regressor to the training data.
 *
 * Variant of svrFit whose training data is passed as three arrays
 * (indptr, indices, data) plus an element count nnz instead of a single
 * matrix pointer. Only the declaration appears here.
 *
 * @tparam math_t Element type of data, y and sample_weight; also
 *   parameterizes the SvmModel that is passed in.
 * @param handle RAFT handle, taken by const reference.
 * @param indptr Integer index-pointer array of the sparse input.
 *   NOTE(review): the indptr/indices/data naming suggests CSR with
 *   n_rows + 1 entries -- confirm.
 * @param indices Integer index array of the sparse input.
 *   NOTE(review): presumably column indices, nnz entries -- confirm.
 * @param data Pointer to the stored math_t values of the sparse input.
 *   NOTE(review): presumably nnz entries -- confirm.
 * @param n_rows Row count. NOTE(review): presumably the number of training
 *   samples -- confirm.
 * @param n_cols Column count. NOTE(review): presumably the number of
 *   features -- confirm.
 * @param nnz NOTE(review): presumably the number of stored (non-zero)
 *   elements -- confirm.
 * @param y Pointer to math_t values. NOTE(review): presumably the regression
 *   targets, one per row -- confirm.
 * @param param SVM parameters, read-only (const reference).
 * @param kernel_params Kernel parameters, taken by non-const reference.
 * @param model Model object, taken by non-const reference.
 *   NOTE(review): presumably filled in by the fit -- confirm.
 * @param sample_weight Optional pointer to per-sample weights; defaults to
 *   nullptr. NOTE(review): nullptr presumably means unweighted -- confirm.
 * @return int. NOTE(review): the meaning of the return value is not visible
 *   from this declaration -- confirm against the definition.
 */
73 template <typename math_t>
74 int svrFitSparse(const raft::handle_t& handle,
75  int* indptr,
76  int* indices,
77  math_t* data,
78  int n_rows,
79  int n_cols,
80  int nnz,
81  math_t* y,
82  const SvmParameter& param,
83  ML::matrix::KernelParams& kernel_params,
84  SvmModel<math_t>& model,
85  const math_t* sample_weight = nullptr);
86 
87 // For prediction we use svcPredict
88 
89 }; // end namespace SVM
90 }; // end namespace ML
int svrFit(const raft::handle_t &handle, math_t *X, int n_rows, int n_cols, math_t *y, const SvmParameter &param, ML::matrix::KernelParams &kernel_params, SvmModel< math_t > &model, const math_t *sample_weight=nullptr)
Fit a support vector regressor to the training data.
int svrFitSparse(const raft::handle_t &handle, int *indptr, int *indices, math_t *data, int n_rows, int n_cols, int nnz, math_t *y, const SvmParameter &param, ML::matrix::KernelParams &kernel_params, SvmModel< math_t > &model, const math_t *sample_weight=nullptr)
Fit a support vector regressor to the training data.
Definition: dbscan.hpp:18
Definition: svm_model.h:24
Definition: svm_parameter.h:27
Definition: kernel_params.hpp:18