fil.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
19 #pragma once
20 
22 
23 #include <stddef.h>
24 
25 #include <variant> // for std::get<>, std::variant<>
26 
// Forward declaration only: every API in this header takes raft::handle_t by
// (const) reference, so the complete RAFT type is not required here.
namespace raft {
class handle_t;
}  // namespace raft
30 
31 namespace ML {
32 namespace fil {
33 
38 enum algo_t {
52 };
53 
69 };
/// Printable names for the storage types, listed in the same order as the
/// storage_type_t enumerators (AUTO, DENSE, SPARSE, SPARSE8).
/// NOTE(review): indexing this table by a storage_type_t value assumes those
/// enumerators are 0..3 in this order -- confirm against the enum definition.
///
/// `inline constexpr` (C++17) gives a single, fully read-only definition shared
/// by all translation units, instead of a separate mutable `static` copy per TU
/// whose pointer elements could be reassigned; it is also usable in constant
/// expressions.
inline constexpr const char* storage_type_repr[] = {"AUTO", "DENSE", "SPARSE", "SPARSE8"};
71 
82 };
83 
/// Forest representation for precision `real_t`. It is only declared (not
/// defined) in this header, so code including it can form pointers/references
/// to it but cannot inspect its contents.
template <class real_t>
struct forest;
86 
88 template <typename real_t>
90 
94 
96 using forest_variant = std::variant<forest_t<float>, forest_t<double>>;
97 
/// Upper bound on how many input samples (items) a single thread processes.
/// When the `n_items` import parameter is 0, the largest count up to this bound
/// that fits into shared memory is chosen.
///
/// Declared `inline constexpr` (C++17) so the header defines one entity shared by
/// every translation unit. A plain namespace-scope `constexpr` has internal
/// linkage, and odr-using it from inline/template code in several TUs (e.g.
/// binding it to `const int&`, as std::min/std::max do) violates the ODR.
inline constexpr int MAX_N_ITEMS = 4;
100 
103  // algo is the inference algorithm
105  // output_class indicates whether thresholding will be applied
106  // to the model output
108  // threshold may be used for thresholding if output_class == true,
109  // and is ignored otherwise. threshold is ignored if leaves store
110  // vectorized class labels. in that case, a class with most votes
111  // is returned regardless of the absolute vote count
112  float threshold;
113  // storage_type indicates whether the forest should be imported as dense or sparse
115  // blocks_per_sm, if nonzero, works as a limit to improve cache hit rate for larger forests
116  // suggested values (if nonzero) are from 2 to 7
117  // if zero, launches ceildiv(num_rows, NITEMS) blocks
119  // threads_per_tree determines how many threads work on a single tree at once inside a block
120  // can only be a power of 2
122  // n_items is how many input samples (items) any thread processes. If 0 is given,
123  // choose most (up to MAX_N_ITEMS) that fit into shared memory.
124  int n_items;
125  // if non-nullptr, *pforest_shape_str will be set to caller-owned string that
126  // contains forest shape
128  // precision in which to load the treelite model
130 };
131 
138 void from_treelite(const raft::handle_t& handle,
139  forest_variant* pforest,
140  TreeliteModelHandle model,
141  const treelite_params_t* tl_params);
142 
147 template <typename real_t>
148 void free(const raft::handle_t& h, forest_t<real_t> f);
149 
162 template <typename real_t>
163 void predict(const raft::handle_t& h,
165  real_t* preds,
166  const real_t* data,
167  size_t num_rows,
168  bool predict_proba = false);
169 
170 } // namespace fil
171 } // namespace ML
void free(const raft::handle_t &h, forest_t< real_t > f)
algo_t
Definition: fil.h:38
@ ALGO_AUTO
Definition: fil.h:41
@ BATCH_TREE_REORG
Definition: fil.h:51
@ TREE_REORG
Definition: fil.h:48
@ NAIVE
Definition: fil.h:44
constexpr int MAX_N_ITEMS
Definition: fil.h:99
storage_type_t
Definition: fil.h:55
@ DENSE
Definition: fil.h:59
@ SPARSE
Definition: fil.h:61
@ AUTO
Definition: fil.h:57
@ SPARSE8
Definition: fil.h:68
void from_treelite(const raft::handle_t &handle, forest_variant *pforest, TreeliteModelHandle model, const treelite_params_t *tl_params)
void predict(const raft::handle_t &h, forest_t< real_t > f, real_t *preds, const real_t *data, size_t num_rows, bool predict_proba=false)
precision_t
Definition: fil.h:73
@ PRECISION_NATIVE
Definition: fil.h:76
@ PRECISION_FLOAT64
Definition: fil.h:81
@ PRECISION_FLOAT32
Definition: fil.h:79
std::variant< forest_t< float >, forest_t< double > > forest_variant
Definition: fil.h:96
Definition: dbscan.hpp:30
Definition: dbscan.hpp:26
Definition: fil.h:85
Definition: fil.h:102
bool output_class
Definition: fil.h:107
float threshold
Definition: fil.h:112
storage_type_t storage_type
Definition: fil.h:114
int n_items
Definition: fil.h:124
int threads_per_tree
Definition: fil.h:121
int blocks_per_sm
Definition: fil.h:118
algo_t algo
Definition: fil.h:104
char ** pforest_shape_str
Definition: fil.h:127
precision_t precision
Definition: fil.h:129
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23