svr.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
20 
21 #include <cublas_v2.h>
22 
23 namespace ML {
24 namespace SVM {
25 
// Forward declarations of the model and parameter types. The function
// declarations in this header only take them by reference, so their full
// definitions are not required here.
26 template <typename math_t>
27 struct SvmModel;
28 struct SvmParameter;
29 
30 // Forward declarations of the stateless API
/**
 * @brief Fit a support vector regressor to the training data.
 *
 * Only the declaration is visible here; the definition lives elsewhere.
 *
 * @tparam math_t element type of X, y and sample_weight
 * @param handle RAFT handle
 * @param X pointer to the training data; presumably n_rows x n_cols values.
 *   Layout (row/column major) and memory space (host/device) are not visible
 *   in this header -- TODO confirm against the definition.
 * @param n_rows presumably the number of training samples -- TODO confirm
 * @param n_cols presumably the number of features per sample -- TODO confirm
 * @param y pointer to the regression targets; presumably n_rows values --
 *   TODO confirm
 * @param param solver parameters (read-only)
 * @param kernel_params kernel parameters, passed by non-const reference;
 *   whether they are modified is not visible here.
 *   NOTE(review): this overload names the type MLCommon::Matrix::KernelParams
 *   while svrFitSparse uses raft::distance::kernels::KernelParams -- confirm
 *   that both names refer to the same type.
 * @param model passed by non-const reference; presumably receives the fitted
 *   model -- TODO confirm
 * @param sample_weight optional per-sample weights; defaults to nullptr.
 *   Presumably n_rows values when provided -- TODO confirm
 */
50 template <typename math_t>
51 void svrFit(const raft::handle_t& handle,
52  math_t* X,
53  int n_rows,
54  int n_cols,
55  math_t* y,
56  const SvmParameter& param,
57  MLCommon::Matrix::KernelParams& kernel_params,
58  SvmModel<math_t>& model,
59  const math_t* sample_weight = nullptr);
60 
/**
 * @brief Fit a support vector regressor to the training data.
 *
 * Variant of svrFit that takes the training data as three arrays
 * (indptr, indices, data) plus a non-zero count instead of a single dense
 * pointer. Only the declaration is visible here; the definition lives
 * elsewhere.
 *
 * @tparam math_t element type of data, y and sample_weight
 * @param handle RAFT handle
 * @param indptr presumably the CSR row-offset array -- TODO confirm
 * @param indices presumably the CSR column-index array -- TODO confirm
 * @param data presumably the CSR non-zero values -- TODO confirm
 * @param n_rows presumably the number of training samples -- TODO confirm
 * @param n_cols presumably the number of features per sample -- TODO confirm
 * @param nnz presumably the number of stored non-zero entries -- TODO confirm
 * @param y pointer to the regression targets; presumably n_rows values --
 *   TODO confirm
 * @param param solver parameters (read-only)
 * @param kernel_params kernel parameters, passed by non-const reference;
 *   whether they are modified is not visible here
 * @param model passed by non-const reference; presumably receives the fitted
 *   model -- TODO confirm
 * @param sample_weight optional per-sample weights; defaults to nullptr.
 *   Presumably n_rows values when provided -- TODO confirm
 */
82 template <typename math_t>
83 void svrFitSparse(const raft::handle_t& handle,
84  int* indptr,
85  int* indices,
86  math_t* data,
87  int n_rows,
88  int n_cols,
89  int nnz,
90  math_t* y,
91  const SvmParameter& param,
92  raft::distance::kernels::KernelParams& kernel_params,
93  SvmModel<math_t>& model,
94  const math_t* sample_weight = nullptr);
95 
96 // No SVR-specific prediction function is declared here; for prediction we use svcPredict.
97 
98 }; // end namespace SVM
99 }; // end namespace ML
void svrFit(const raft::handle_t &handle, math_t *X, int n_rows, int n_cols, math_t *y, const SvmParameter &param, MLCommon::Matrix::KernelParams &kernel_params, SvmModel< math_t > &model, const math_t *sample_weight=nullptr)
Fit a support vector regressor to the training data.
void svrFitSparse(const raft::handle_t &handle, int *indptr, int *indices, math_t *data, int n_rows, int n_cols, int nnz, math_t *y, const SvmParameter &param, raft::distance::kernels::KernelParams &kernel_params, SvmModel< math_t > &model, const math_t *sample_weight=nullptr)
Fit a support vector regressor to the training data.
Definition: dbscan.hpp:30
Definition: svm_model.h:35
Definition: svm_parameter.h:34