Loading [MathJax]/extensions/tex2jax.js
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
lars.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2020-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <cuml/common/logger.hpp>
20 
21 #include <raft/core/handle.hpp>
22 
23 namespace ML {
24 namespace Solver {
25 namespace Lars {
26 
58 template <typename math_t, typename idx_t>
59 void larsFit(const raft::handle_t& handle,
60  math_t* X,
61  idx_t n_rows,
62  idx_t n_cols,
63  const math_t* y,
64  math_t* beta,
65  idx_t* active_idx,
66  math_t* alphas,
67  idx_t* n_active,
68  math_t* Gram,
69  int max_iter,
70  math_t* coef_path,
71  rapids_logger::level_enum verbosity,
72  idx_t ld_X,
73  idx_t ld_G,
74  math_t eps);
75 
93 template <typename math_t, typename idx_t>
94 void larsPredict(const raft::handle_t& handle,
95  const math_t* X,
96  idx_t n_rows,
97  idx_t n_cols,
98  idx_t ld_X,
99  const math_t* beta,
100  idx_t n_active,
101  idx_t* active_idx,
102  math_t intercept,
103  math_t* preds);
104 }; // namespace Lars
105 }; // namespace Solver
106 }; // end namespace ML
void larsFit(const raft::handle_t &handle, math_t *X, idx_t n_rows, idx_t n_cols, const math_t *y, math_t *beta, idx_t *active_idx, math_t *alphas, idx_t *n_active, math_t *Gram, int max_iter, math_t *coef_path, rapids_logger::level_enum verbosity, idx_t ld_X, idx_t ld_G, math_t eps)
Train a regressor using LARS method.
void larsPredict(const raft::handle_t &handle, const math_t *X, idx_t n_rows, idx_t n_cols, idx_t ld_X, const math_t *beta, idx_t n_active, idx_t *active_idx, math_t intercept, math_t *preds)
Predict with LARS regressor.
Definition: dbscan.hpp:30