randomforest.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <cuml/common/logger.hpp>
22 
23 #include <map>
24 #include <memory>
25 
// Forward declaration only: every API below takes raft::handle_t by const
// reference, so the complete type is not required in this header.
namespace raft {
class handle_t;
}  // namespace raft
29 
30 namespace ML {
31 
32 enum RF_type {
35 };
36 
38 
39 struct RF_metrics {
41 
42  // Classification metrics
43  float accuracy;
44 
45  // Regression metrics
49 };
50 
52  float accuracy,
53  double mean_abs_error,
54  double mean_squared_error,
55  double median_abs_error);
57 RF_metrics set_rf_metrics_regression(double mean_abs_error,
58  double mean_squared_error,
59  double median_abs_error);
60 void print(const RF_metrics rf_metrics);
61 
62 struct RF_params {
66  int n_trees;
77  bool bootstrap;
81  float max_samples;
88  uint64_t seed;
94  int n_streams;
96 };
97 
98 /* Remap labels in place so they take the unique, consecutive values 0 .. n_unique_vals - 1,
99  and record the old_label -> new_label mapping for this random forest in labels_map.
100 */
101 void preprocess_labels(int n_rows,
102  std::vector<int>& labels,
103  std::map<int, int>& labels_map,
104  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
105 
106 /* Revert the effect of preprocess_labels, if needed: map labels back to their original values using labels_map. */
107 void postprocess_labels(int n_rows,
108  std::vector<int>& labels,
109  std::map<int, int>& labels_map,
110  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
111 
112 template <class T, class L>
114  std::vector<std::shared_ptr<DT::TreeMetaDataNode<T, L>>> trees;
116 };
117 
118 template <class T, class L>
120 
121 template <class T, class L>
123 
124 template <class T, class L>
126 
127 template <class T, class L>
128 std::string get_rf_json(const RandomForestMetaData<T, L>* forest);
129 
130 template <class T, class L>
132  const RandomForestMetaData<T, L>* forest,
133  int num_features);
134 
135 // ----------------------------- Classification ----------------------------------- //
136 
139 
140 void fit(const raft::handle_t& user_handle,
141  RandomForestClassifierF* forest,
142  float* input,
143  int n_rows,
144  int n_cols,
145  int* labels,
146  int n_unique_labels,
147  RF_params rf_params,
148  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
149 void fit(const raft::handle_t& user_handle,
150  RandomForestClassifierD* forest,
151  double* input,
152  int n_rows,
153  int n_cols,
154  int* labels,
155  int n_unique_labels,
156  RF_params rf_params,
157  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
158 
159 template <typename T, typename L>
160 void fit_treelite(const raft::handle_t& user_handle,
161  TreeliteModelHandle* model,
162  T* input,
163  int n_rows,
164  int n_cols,
165  L* labels,
166  int n_unique_labels,
167  RF_params rf_params,
168  rapids_logger::level_enum verbosity);
169 
170 void predict(const raft::handle_t& user_handle,
171  const RandomForestClassifierF* forest,
172  const float* input,
173  int n_rows,
174  int n_cols,
175  int* predictions,
176  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
177 void predict(const raft::handle_t& user_handle,
178  const RandomForestClassifierD* forest,
179  const double* input,
180  int n_rows,
181  int n_cols,
182  int* predictions,
183  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
184 
185 RF_metrics score(const raft::handle_t& user_handle,
186  const RandomForestClassifierF* forest,
187  const int* ref_labels,
188  int n_rows,
189  const int* predictions,
190  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
191 RF_metrics score(const raft::handle_t& user_handle,
192  const RandomForestClassifierD* forest,
193  const int* ref_labels,
194  int n_rows,
195  const int* predictions,
196  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
197 
198 RF_params set_rf_params(int max_depth,
199  int max_leaves,
200  float max_features,
201  int max_n_bins,
202  int min_samples_leaf,
203  int min_samples_split,
204  float min_impurity_decrease,
205  bool bootstrap,
206  int n_trees,
207  float max_samples,
208  uint64_t seed,
209  CRITERION split_criterion,
210  int cfg_n_streams,
211  int max_batch_size);
212 
213 // ----------------------------- Regression ----------------------------------- //
214 
217 
218 void fit(const raft::handle_t& user_handle,
219  RandomForestRegressorF* forest,
220  float* input,
221  int n_rows,
222  int n_cols,
223  float* labels,
224  RF_params rf_params,
225  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
226 void fit(const raft::handle_t& user_handle,
227  RandomForestRegressorD* forest,
228  double* input,
229  int n_rows,
230  int n_cols,
231  double* labels,
232  RF_params rf_params,
233  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
234 
235 template <typename T, typename L>
236 void fit_treelite(const raft::handle_t& user_handle,
237  TreeliteModelHandle* model,
238  T* input,
239  int n_rows,
240  int n_cols,
241  L* labels,
242  RF_params rf_params,
243  rapids_logger::level_enum verbosity);
244 
245 void predict(const raft::handle_t& user_handle,
246  const RandomForestRegressorF* forest,
247  const float* input,
248  int n_rows,
249  int n_cols,
250  float* predictions,
251  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
252 void predict(const raft::handle_t& user_handle,
253  const RandomForestRegressorD* forest,
254  const double* input,
255  int n_rows,
256  int n_cols,
257  double* predictions,
258  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
259 
260 RF_metrics score(const raft::handle_t& user_handle,
261  const RandomForestRegressorF* forest,
262  const float* ref_labels,
263  int n_rows,
264  const float* predictions,
265  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
266 RF_metrics score(const raft::handle_t& user_handle,
267  const RandomForestRegressorD* forest,
268  const double* ref_labels,
269  int n_rows,
270  const double* predictions,
271  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
272 }; // namespace ML
Definition: dbscan.hpp:29
void preprocess_labels(int n_rows, std::vector< int > &labels, std::map< int, int > &labels_map, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
std::string get_rf_json(const RandomForestMetaData< T, L > *forest)
void fit(const raft::handle_t &user_handle, RandomForestClassifierF *forest, float *input, int n_rows, int n_cols, int *labels, int n_unique_labels, RF_params rf_params, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
std::string get_rf_summary_text(const RandomForestMetaData< T, L > *forest)
RandomForestMetaData< float, int > RandomForestClassifierF
Definition: randomforest.hpp:137
void print(const RF_metrics rf_metrics)
void delete_rf_metadata(RandomForestMetaData< T, L > *forest)
RF_type
Definition: randomforest.hpp:32
@ REGRESSION
Definition: randomforest.hpp:34
@ CLASSIFICATION
Definition: randomforest.hpp:33
RF_metrics set_all_rf_metrics(RF_type rf_type, float accuracy, double mean_abs_error, double mean_squared_error, double median_abs_error)
RandomForestMetaData< double, double > RandomForestRegressorD
Definition: randomforest.hpp:216
RandomForestMetaData< float, float > RandomForestRegressorF
Definition: randomforest.hpp:215
void postprocess_labels(int n_rows, std::vector< int > &labels, std::map< int, int > &labels_map, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
void build_treelite_forest(TreeliteModelHandle *model, const RandomForestMetaData< T, L > *forest, int num_features)
CRITERION
Definition: algo_helper.h:20
RF_metrics set_rf_metrics_classification(float accuracy)
std::string get_rf_detailed_text(const RandomForestMetaData< T, L > *forest)
RF_metrics score(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const int *ref_labels, int n_rows, const int *predictions, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
RF_params set_rf_params(int max_depth, int max_leaves, float max_features, int max_n_bins, int min_samples_leaf, int min_samples_split, float min_impurity_decrease, bool bootstrap, int n_trees, float max_samples, uint64_t seed, CRITERION split_criterion, int cfg_n_streams, int max_batch_size)
RF_metrics set_rf_metrics_regression(double mean_abs_error, double mean_squared_error, double median_abs_error)
void predict(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const float *input, int n_rows, int n_cols, int *predictions, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
RandomForestMetaData< double, int > RandomForestClassifierD
Definition: randomforest.hpp:138
void fit_treelite(const raft::handle_t &user_handle, TreeliteModelHandle *model, T *input, int n_rows, int n_cols, L *labels, int n_unique_labels, RF_params rf_params, rapids_logger::level_enum verbosity)
task_category
Definition: randomforest.hpp:37
@ REGRESSION_MODEL
Definition: randomforest.hpp:37
@ CLASSIFICATION_MODEL
Definition: randomforest.hpp:37
Definition: dbscan.hpp:25
Definition: decisiontree.hpp:29
Definition: randomforest.hpp:39
RF_type rf_type
Definition: randomforest.hpp:40
double mean_squared_error
Definition: randomforest.hpp:47
double median_abs_error
Definition: randomforest.hpp:48
float accuracy
Definition: randomforest.hpp:43
double mean_abs_error
Definition: randomforest.hpp:46
Definition: randomforest.hpp:62
uint64_t seed
Definition: randomforest.hpp:88
int n_streams
Definition: randomforest.hpp:94
DT::DecisionTreeParams tree_params
Definition: randomforest.hpp:95
bool bootstrap
Definition: randomforest.hpp:77
int n_trees
Definition: randomforest.hpp:66
float max_samples
Definition: randomforest.hpp:81
Definition: randomforest.hpp:113
RF_params rf_params
Definition: randomforest.hpp:115
std::vector< std::shared_ptr< DT::TreeMetaDataNode< T, L > > > trees
Definition: randomforest.hpp:114
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23