lars.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2020-2022, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <raft/core/handle.hpp>
20 
21 namespace ML {
22 namespace Solver {
23 namespace Lars {
24 
56 template <typename math_t, typename idx_t>
57 void larsFit(const raft::handle_t& handle,
58  math_t* X,
59  idx_t n_rows,
60  idx_t n_cols,
61  const math_t* y,
62  math_t* beta,
63  idx_t* active_idx,
64  math_t* alphas,
65  idx_t* n_active,
66  math_t* Gram,
67  int max_iter,
68  math_t* coef_path,
69  int verbosity,
70  idx_t ld_X,
71  idx_t ld_G,
72  math_t eps);
73 
91 template <typename math_t, typename idx_t>
92 void larsPredict(const raft::handle_t& handle,
93  const math_t* X,
94  idx_t n_rows,
95  idx_t n_cols,
96  idx_t ld_X,
97  const math_t* beta,
98  idx_t n_active,
99  idx_t* active_idx,
100  math_t intercept,
101  math_t* preds);
102 }; // namespace Lars
103 }; // namespace Solver
104 }; // end namespace ML
void larsPredict(const raft::handle_t &handle, const math_t *X, idx_t n_rows, idx_t n_cols, idx_t ld_X, const math_t *beta, idx_t n_active, idx_t *active_idx, math_t intercept, math_t *preds)
Predict with LARS regressor.
void larsFit(const raft::handle_t &handle, math_t *X, idx_t n_rows, idx_t n_cols, const math_t *y, math_t *beta, idx_t *active_idx, math_t *alphas, idx_t *n_active, math_t *Gram, int max_iter, math_t *coef_path, int verbosity, idx_t ld_X, idx_t ld_G, math_t eps)
Train a regressor using LARS method.
Definition: dbscan.hpp:30