C-Support Vector Classification.
#include <svc.hpp>
Public Member Functions | |
SVC (raft::handle_t &handle, math_t C=1, math_t tol=1.0e-3, raft::distance::kernels::KernelParams kernel_params=raft::distance::kernels::KernelParams{raft::distance::kernels::LINEAR, 3, 1, 0}, math_t cache_size=200, int max_iter=-1, int nochange_steps=1000, int verbosity=CUML_LEVEL_INFO) | |
Constructs a support vector classifier. | |
~SVC () | |
void | fit (math_t *input, int n_rows, int n_cols, math_t *labels, const math_t *sample_weight=nullptr) |
Fit a support vector classifier to the training data. | |
void | predict (math_t *input, int n_rows, int n_cols, math_t *preds) |
Predict classes for samples in input. | |
void | decisionFunction (math_t *input, int n_rows, int n_cols, math_t *preds) |
Calculate decision function value for samples in input. | |
Public Attributes | |
raft::distance::kernels::KernelParams | kernel_params |
SvmParameter | param |
SvmModel< math_t > | model |
C-Support Vector Classification.
This is a scikit-learn-like wrapper around the stateless C++ functions. See Issue #456 for a general discussion of stateful scikit-learn-like wrappers.
The classifier is fitted with the Sequential Minimal Optimization (SMO) algorithm, which works on the dual problem.
The decision function takes the following form:
\[ \operatorname{sign}\left( \sum_{i=1}^{N_{support}} y_i \alpha_i K(x_i, x) + b \right), \]
where \(x_i\) are the support vectors, \( y_i \alpha_i \) are the dual coefficients, and \(b\) is the bias term.
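The formula above can be illustrated with a small host-side sketch. This is not the cuML implementation, which runs on the device: the helper names `linear_kernel`, `decision_value`, and `predict_sign` are hypothetical, a linear kernel is assumed, and `double` stands in for `math_t`.

```cpp
#include <cstddef>
#include <vector>

// Linear kernel K(a, b) = <a, b>. Assumption: stands in for whichever
// kernel is selected through kernel_params.
double linear_kernel(const std::vector<double>& a, const std::vector<double>& b)
{
  double dot = 0.0;
  for (std::size_t k = 0; k < a.size(); ++k) dot += a[k] * b[k];
  return dot;
}

// Hypothetical helper: evaluates sum_i dual_coef[i] * K(x_i, x) + b,
// where dual_coef[i] holds the product y_i * alpha_i.
double decision_value(const std::vector<std::vector<double>>& support_vectors,
                      const std::vector<double>& dual_coef,
                      double b,
                      const std::vector<double>& x)
{
  double f = b;
  for (std::size_t i = 0; i < support_vectors.size(); ++i)
    f += dual_coef[i] * linear_kernel(support_vectors[i], x);
  return f;
}

// The sign of the decision value selects the class (+1 or -1).
int predict_sign(double f) { return f >= 0.0 ? 1 : -1; }
```

In this sketch only the support vectors enter the sum, which is why the cost of evaluating one sample scales with the number of support vectors rather than with the size of the training set.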
The penalty parameter C bounds the dual variables, and therefore the magnitude of the dual coefficients:
\[ 0 \le \alpha_i \le C \]
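For orientation, the standard textbook form of the C-SVC dual problem that SMO-type solvers address is sketched below. It is given as background under the usual unweighted formulation and is not taken from the cuML sources:
\[ \min_{\alpha} \; \frac{1}{2} \sum_{i,j} \alpha_i \alpha_j y_i y_j K(x_i, x_j) - \sum_{i} \alpha_i \quad \text{subject to} \quad 0 \le \alpha_i \le C, \quad \sum_{i} y_i \alpha_i = 0. \]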
ML::SVM::SVC< math_t >::SVC | ( | raft::handle_t & | handle, |
math_t | C = 1, |
math_t | tol = 1.0e-3, |
raft::distance::kernels::KernelParams | kernel_params = raft::distance::kernels::KernelParams{raft::distance::kernels::LINEAR, 3, 1, 0}, |
math_t | cache_size = 200, |
int | max_iter = -1, |
int | nochange_steps = 1000, |
int | verbosity = CUML_LEVEL_INFO |
) |
Constructs a support vector classifier.
handle | cuML handle |
C | penalty term |
tol | tolerance to stop fitting |
kernel_params | parameters for kernels |
cache_size | size of kernel cache in device memory (MiB) |
max_iter | maximum number of outer iterations in SmoSolver |
nochange_steps | number of steps with no change with respect to convergence |
verbosity | verbosity level for logging messages during execution |
ML::SVM::SVC< math_t >::~SVC | ( | ) |
void ML::SVM::SVC< math_t >::decisionFunction | ( | math_t * | input, |
int | n_rows, | ||
int | n_cols, | ||
math_t * | preds | ||
) |
Calculate decision function value for samples in input.
[in] | input | device pointer for the input data in column-major format, size [n_rows x n_cols]. |
[in] | n_rows | number of vectors |
[in] | n_cols | number of features |
[out] | preds | device pointer to store the decision function values. Size [n_rows]. Should be allocated on entry. |
void ML::SVM::SVC< math_t >::fit | ( | math_t * | input, |
int | n_rows, | ||
int | n_cols, | ||
math_t * | labels, | ||
const math_t * | sample_weight = nullptr |
) |
Fit a support vector classifier to the training data.
Each row of the input data stores a feature vector. We use the SMO method to fit the SVM.
input | device pointer for the input data in column-major format. Size [n_rows x n_cols]. |
n_rows | number of rows (training vectors) |
n_cols | number of columns (features) |
labels | device pointer for the labels. Size [n_rows]. |
[in] | sample_weight | optional sample weights, size [n_rows] |
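Because the input is expected in column-major format, data stored sample by sample (row-major) has to be rearranged before it is handed over. The following is a minimal host-side sketch of that rearrangement. The helper name `to_column_major` is hypothetical, `float` is assumed for `math_t`, and the resulting buffer would still need to be copied to device memory before calling `fit`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: converts a row-major host matrix [n_rows x n_cols]
// to column-major order. Element (i, j) moves from
// row_major[i * n_cols + j] to col_major[j * n_rows + i].
std::vector<float> to_column_major(const std::vector<float>& row_major,
                                   int n_rows,
                                   int n_cols)
{
  std::vector<float> col_major(row_major.size());
  for (int i = 0; i < n_rows; ++i) {
    for (int j = 0; j < n_cols; ++j) {
      col_major[static_cast<std::size_t>(j) * n_rows + i] =
        row_major[static_cast<std::size_t>(i) * n_cols + j];
    }
  }
  return col_major;
}
```

With this layout, all values of one feature are contiguous in memory, and feature `j` of sample `i` is found at index `j * n_rows + i`.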
void ML::SVM::SVC< math_t >::predict | ( | math_t * | input, |
int | n_rows, | ||
int | n_cols, | ||
math_t * | preds | ||
) |
Predict classes for samples in input.
[in] | input | device pointer for the input data in column-major format, size [n_rows x n_cols]. |
[in] | n_rows | number of vectors |
[in] | n_cols | number of features |
[out] | preds | device pointer to store the predicted class labels. Size [n_rows]. Should be allocated on entry. |
raft::distance::kernels::KernelParams ML::SVM::SVC< math_t >::kernel_params |
SvmModel<math_t> ML::SVM::SVC< math_t >::model |
SvmParameter ML::SVM::SVC< math_t >::param |