Home
libcuml
cucim
cudf-java
cudf
cugraph
cuml
cuproj
cuspatial
cuvs
cuxfilter
dask-cuda
dask-cudf
kvikio
libcudf
libcuml
libcuproj
libcuspatial
libkvikio
librapidsmpf
librmm
libucxx
raft
rapids-cmake
rapidsmpf
rmm
ucxx
stable (25.12)
nightly (26.02)
stable (25.12)
legacy (25.10)
include
cuml
linear_model
qn.h
Go to the documentation of this file.
/*
 * SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */

#pragma once

#include <stdbool.h>

#ifdef __cplusplus
namespace ML::GLM {

extern "C" {
#endif

/**
 * @brief Loss functions selectable through qn_params::loss.
 *
 * Every enumerator carries an explicit value, so the numbering does not depend
 * on declaration order. The enum is also exposed to C (see the C-only typedef
 * below), so NOTE(review): treat these values as part of the C-facing
 * interface and keep them stable.
 */
enum qn_loss_type {
  /** Logistic loss (presumably binary logistic regression; confirm). */
  QN_LOSS_LOGISTIC = 0,
  /** Squared loss (presumably least-squares regression; confirm). */
  QN_LOSS_SQUARED = 1,
  /** Softmax loss (presumably multinomial classification; confirm). */
  QN_LOSS_SOFTMAX = 2,
  /** SVC loss, L1 variant (presumably hinge loss; confirm). */
  QN_LOSS_SVC_L1 = 3,
  /** SVC loss, L2 variant (presumably squared hinge loss; confirm). */
  QN_LOSS_SVC_L2 = 4,
  /** SVR loss, L1 variant (presumably epsilon-insensitive loss; confirm). */
  QN_LOSS_SVR_L1 = 5,
  /** SVR loss, L2 variant (presumably squared epsilon-insensitive loss; confirm). */
  QN_LOSS_SVR_L2 = 6,
  /** Absolute loss (presumably absolute-error regression; confirm). */
  QN_LOSS_ABS = 7,
  /** Sentinel value; it is what the C++ constructor of qn_params stores in `loss`. */
  QN_LOSS_UNKNOWN = 99
};

#ifndef __cplusplus
/* C does not make the enum tag usable as a plain type name; this typedef does. */
typedef enum qn_loss_type qn_loss_type;
#endif

/**
 * @brief Parameter block for the quasi-Newton GLM solver.
 *
 * The same struct is visible from C and C++. The only C++-specific member is
 * the non-virtual default constructor, which does not change the object
 * layout. C code gets no defaults and must set every field itself. The
 * "C++ default" values listed per field are the ones the constructor stores.
 */
struct qn_params {
  /** Loss to minimize. C++ default: QN_LOSS_UNKNOWN, so callers presumably
   *  must pick a real loss explicitly (NOTE(review): confirm against solver). */
  qn_loss_type loss;
  /** L1 penalty strength (name-derived; confirm). C++ default: 0. */
  double penalty_l1;
  /** L2 penalty strength (name-derived; confirm). C++ default: 0. */
  double penalty_l2;
  /** Gradient tolerance, presumably a convergence criterion (confirm). C++ default: 1e-4. */
  double grad_tol;
  /** Change tolerance, presumably on the per-iteration change (confirm). C++ default: 1e-5. */
  double change_tol;
  /** Iteration cap for the solver (name-derived). C++ default: 1000. */
  int max_iter;
  /** Iteration cap for the line search (name-derived; confirm scope). C++ default: 50. */
  int linesearch_max_iter;
  /** L-BFGS history size (name-derived; confirm units). C++ default: 5. */
  int lbfgs_memory;
  /** Verbosity level (name-derived). C++ default: 0. */
  int verbose;
  /** Whether an intercept term is fitted (name-derived). C++ default: true. */
  bool fit_intercept;
  /** Whether penalties are normalized (presumably by sample count; confirm). C++ default: true. */
  bool penalty_normalized;

#ifdef __cplusplus
  /**
   * Stores the C++ defaults listed on each field above.
   *
   * It is `constexpr` so that C++ callers can form default parameters at
   * compile time (e.g. `constexpr qn_params defaults{};`). It is `noexcept`
   * because it only stores scalar constants. Neither specifier affects the
   * layout seen from C.
   */
  constexpr qn_params() noexcept
    : loss(QN_LOSS_UNKNOWN),
      penalty_l1(0.0),
      penalty_l2(0.0),
      grad_tol(1e-4),
      change_tol(1e-5),
      max_iter(1000),
      linesearch_max_iter(50),
      lbfgs_memory(5),
      verbose(0),
      fit_intercept(true),
      penalty_normalized(true)
  {
  }
#endif
};

#ifndef __cplusplus
/* C does not make the struct tag usable as a plain type name; this typedef does. */
typedef struct qn_params qn_params;
#endif

#ifdef __cplusplus
}  // extern "C"
}  // namespace ML::GLM
#endif
ML::GLM — Definition: glm.hpp:12
qn_params — struct qn_params qn_params — Definition: qn.h:106
qn_loss_type — qn_loss_type — Definition: qn.h:16
QN_LOSS_UNKNOWN — @ QN_LOSS_UNKNOWN — Definition: qn.h:50
QN_LOSS_SOFTMAX — @ QN_LOSS_SOFTMAX — Definition: qn.h:28
QN_LOSS_SVR_L1 — @ QN_LOSS_SVR_L1 — Definition: qn.h:40
QN_LOSS_SQUARED — @ QN_LOSS_SQUARED — Definition: qn.h:24
QN_LOSS_SVR_L2 — @ QN_LOSS_SVR_L2 — Definition: qn.h:44
QN_LOSS_SVC_L1 — @ QN_LOSS_SVC_L1 — Definition: qn.h:32
QN_LOSS_ABS — @ QN_LOSS_ABS — Definition: qn.h:48
QN_LOSS_SVC_L2 — @ QN_LOSS_SVC_L2 — Definition: qn.h:36
QN_LOSS_LOGISTIC — @ QN_LOSS_LOGISTIC — Definition: qn.h:20
qn_params — Definition: qn.h:56
qn_params::fit_intercept — bool fit_intercept — Definition: qn.h:76
qn_params::lbfgs_memory — int lbfgs_memory — Definition: qn.h:72
qn_params::grad_tol — double grad_tol — Definition: qn.h:64
qn_params::penalty_normalized — bool penalty_normalized — Definition: qn.h:85
qn_params::penalty_l1 — double penalty_l1 — Definition: qn.h:60
qn_params::loss — qn_loss_type loss — Definition: qn.h:58
qn_params::verbose — int verbose — Definition: qn.h:74
qn_params::change_tol — double change_tol — Definition: qn.h:66
qn_params::max_iter — int max_iter — Definition: qn.h:68
qn_params::penalty_l2 — double penalty_l2 — Definition: qn.h:62
qn_params::linesearch_max_iter — int linesearch_max_iter — Definition: qn.h:70
Generated by Doxygen 1.9.1