randomforest.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
8 #include <cuml/common/logger.hpp>
11 
12 #include <map>
13 #include <memory>
14 
// Forward declaration only: every API in this header that uses the handle
// takes it as `const raft::handle_t&`, so the complete type is not required
// here and raft headers need not be included.
15 namespace raft {
16 class handle_t; // forward decl
17 }
18 
19 namespace ML {
20 
21 enum RF_type {
24 };
25 
27 
28 struct RF_metrics {
30 
31  // Classification metrics
32  float accuracy;
33 
34  // Regression metrics
38 };
39 
41  float accuracy,
42  double mean_abs_error,
43  double mean_squared_error,
44  double median_abs_error);
/**
 * Build an RF_metrics value from the three regression metrics
 * (the "Regression metrics" fields of RF_metrics).
 *
 * @param mean_abs_error     value for RF_metrics::mean_abs_error.
 * @param mean_squared_error value for RF_metrics::mean_squared_error.
 * @param median_abs_error   value for RF_metrics::median_abs_error.
 * @return the populated metrics.
 * NOTE(review): presumably sets rf_type to REGRESSION and leaves the
 * classification field (accuracy) at a placeholder -- confirm in the .cpp.
 */
46 RF_metrics set_rf_metrics_regression(double mean_abs_error,
47  double mean_squared_error,
48  double median_abs_error);
/**
 * Emit the contents of an RF_metrics value.
 * NOTE(review): output destination (stdout vs. logger) and which fields are
 * printed for each rf_type are not visible here -- confirm.
 * The argument is taken by value; the top-level `const` has no effect on
 * callers.
 */
49 void print(const RF_metrics rf_metrics);
50 
51 struct RF_params {
55  int n_trees;
66  bool bootstrap;
70  float max_samples;
77  uint64_t seed;
83  int n_streams;
85 };
86 
87 /* Update labels in place so they become contiguous integers in [0, n_unique_vals).
88  Record the old_label -> new_label mapping in labels_map (one map per random forest).
89 */
/**
 * @param n_rows     number of labels to process; assumed to equal
 *                   labels.size() -- TODO confirm.
 * @param labels     [in,out] labels, rewritten in place (non-const reference).
 * @param labels_map [out] receives the old_label -> new_label mapping.
 * @param verbosity  logging level (default: info).
 */
90 void preprocess_labels(int n_rows,
91  std::vector<int>& labels,
92  std::map<int, int>& labels_map,
93  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
94 
95 /* Revert the relabeling done by preprocess_labels(), using labels_map, if needed. */
/**
 * @param n_rows     number of labels to process; assumed to equal
 *                   labels.size() -- TODO confirm.
 * @param labels     [in,out] remapped labels, rewritten in place.
 * @param labels_map mapping previously produced by preprocess_labels().
 *                   Taken by non-const reference, so it may be modified --
 *                   confirm in the implementation.
 * @param verbosity  logging level (default: info).
 */
96 void postprocess_labels(int n_rows,
97  std::vector<int>& labels,
98  std::map<int, int>& labels_map,
99  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
100 
101 template <class T, class L>
103  std::vector<std::shared_ptr<DT::TreeMetaDataNode<T, L>>> trees;
105 
109  int n_features = 0;
110 };
111 
112 template <class T, class L>
114 
115 template <class T, class L>
117 
118 template <class T, class L>
120 
/**
 * Return a textual dump of the forest. Per the name it is JSON; the schema is
 * defined by the implementation (not visible here).
 *
 * @tparam T     feature/threshold value type of the forest.
 * @tparam L     label type of the forest.
 * @param forest forest to describe (read-only).
 */
121 template <class T, class L>
122 std::string get_rf_json(const RandomForestMetaData<T, L>* forest);
123 
124 template <class T, class L>
126  const RandomForestMetaData<T, L>* forest,
127  int num_features);
128 
/**
 * Compute the feature importances of the trained RandomForest model.
 *
 * @param forest      trained forest (read-only).
 * @param importances [out] caller-provided buffer of T. Presumably holds one
 *                    value per feature (forest->n_features entries) --
 *                    confirm the required size with the implementation.
 */
136 template <class T, class L>
137 void compute_feature_importances(const RandomForestMetaData<T, L>* forest, T* importances);
138 
139 // ----------------------------- Classification ----------------------------------- //
140 
143 
/**
 * Fit a random forest classifier (float and double feature overloads).
 *
 * @param user_handle     raft library handle.
 * @param forest          [out] model populated by training (non-const pointer).
 * @param input           feature data with n_rows x n_cols entries.
 *                        NOTE(review): memory layout (row/column-major) and
 *                        memory space (host/device) are not visible here.
 * @param n_rows          number of samples.
 * @param n_cols          number of features.
 * @param labels          integer class labels, presumably one per row.
 * @param n_unique_labels number of distinct class labels.
 * @param rf_params       forest hyper-parameters (passed by value).
 * @param verbosity       logging level (default: info).
 * @param bootstrap_masks optional output (default nullptr). Presumably
 *                        receives per-tree bootstrap sample masks when
 *                        non-null -- confirm required size and layout.
 */
144 void fit(const raft::handle_t& user_handle,
145  RandomForestClassifierF* forest,
146  float* input,
147  int n_rows,
148  int n_cols,
149  int* labels,
150  int n_unique_labels,
151  RF_params rf_params,
152  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info,
153  bool* bootstrap_masks = nullptr);
154 void fit(const raft::handle_t& user_handle,
155  RandomForestClassifierD* forest,
156  double* input,
157  int n_rows,
158  int n_cols,
159  int* labels,
160  int n_unique_labels,
161  RF_params rf_params,
162  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info,
163  bool* bootstrap_masks = nullptr);
164 
/**
 * Classification variant of fit() whose result is delivered through a
 * TreeliteModelHandle (a `void*` typedef) instead of a RandomForestMetaData.
 *
 * @tparam T                  feature value type.
 * @tparam L                  label type.
 * @param model               [out] presumably receives the trained Treelite
 *                            model; ownership/lifetime of the handle is not
 *                            visible here -- confirm.
 * @param n_unique_labels     number of distinct class labels.
 * @param bootstrap_masks     optional bootstrap-mask output (see fit()).
 * @param feature_importances [out] presumably one importance per feature
 *                            (n_cols entries) -- confirm.
 * @param verbosity           logging level.
 * Unlike fit(), neither bootstrap_masks nor verbosity has a default here, so
 * every argument must be supplied.
 */
165 template <typename T, typename L>
166 void fit_treelite(const raft::handle_t& user_handle,
167  TreeliteModelHandle* model,
168  T* input,
169  int n_rows,
170  int n_cols,
171  L* labels,
172  int n_unique_labels,
173  RF_params rf_params,
174  bool* bootstrap_masks,
175  T* feature_importances,
176  rapids_logger::level_enum verbosity);
177 
/**
 * Predict class labels with a trained classifier (float/double overloads).
 *
 * @param forest      trained model (read-only).
 * @param input       feature data with n_rows x n_cols entries (read-only).
 * @param n_rows      number of samples to predict.
 * @param n_cols      number of features.
 * @param predictions [out] predicted labels, presumably n_rows entries.
 * @param verbosity   logging level (default: info).
 */
178 void predict(const raft::handle_t& user_handle,
179  const RandomForestClassifierF* forest,
180  const float* input,
181  int n_rows,
182  int n_cols,
183  int* predictions,
184  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
185 void predict(const raft::handle_t& user_handle,
186  const RandomForestClassifierD* forest,
187  const double* input,
188  int n_rows,
189  int n_cols,
190  int* predictions,
191  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
192 
/**
 * Score classifier predictions against reference labels.
 *
 * @param forest      model the predictions came from (read-only).
 * @param ref_labels  ground-truth labels, n_rows entries (read-only).
 * @param n_rows      number of samples compared.
 * @param predictions predicted labels, n_rows entries (read-only).
 * @param verbosity   logging level (default: info).
 * @return RF_metrics. Presumably the classification field (accuracy) is the
 *         meaningful one here -- confirm.
 */
193 RF_metrics score(const raft::handle_t& user_handle,
194  const RandomForestClassifierF* forest,
195  const int* ref_labels,
196  int n_rows,
197  const int* predictions,
198  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
199 RF_metrics score(const raft::handle_t& user_handle,
200  const RandomForestClassifierD* forest,
201  const int* ref_labels,
202  int n_rows,
203  const int* predictions,
204  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
205 
/**
 * Assemble an RF_params value from individual hyper-parameters.
 *
 * Forest-level arguments (bootstrap, n_trees, max_samples, seed) share names
 * with RF_params fields. cfg_n_streams presumably populates
 * RF_params::n_streams, and the per-tree arguments (max_depth, max_leaves,
 * max_features, max_n_bins, min_samples_leaf, min_samples_split,
 * min_impurity_decrease, split_criterion, max_batch_size) presumably populate
 * RF_params::tree_params -- confirm the mapping in the implementation.
 *
 * @return the populated parameter struct (by value).
 */
206 RF_params set_rf_params(int max_depth,
207  int max_leaves,
208  float max_features,
209  int max_n_bins,
210  int min_samples_leaf,
211  int min_samples_split,
212  float min_impurity_decrease,
213  bool bootstrap,
214  int n_trees,
215  float max_samples,
216  uint64_t seed,
217  CRITERION split_criterion,
218  int cfg_n_streams,
219  int max_batch_size);
220 
221 // ----------------------------- Regression ----------------------------------- //
222 
225 
/**
 * Fit a random forest regressor (float and double overloads).
 *
 * @param user_handle     raft library handle.
 * @param forest          [out] model populated by training (non-const pointer).
 * @param input           feature data with n_rows x n_cols entries.
 *                        NOTE(review): layout and memory space not visible.
 * @param n_rows          number of samples.
 * @param n_cols          number of features.
 * @param labels          regression targets of the same type as the
 *                        features, presumably one per row.
 * @param rf_params       forest hyper-parameters (passed by value).
 * @param verbosity       logging level (default: info).
 * @param bootstrap_masks optional output (default nullptr); see the
 *                        classification fit() -- confirm size and layout.
 */
226 void fit(const raft::handle_t& user_handle,
227  RandomForestRegressorF* forest,
228  float* input,
229  int n_rows,
230  int n_cols,
231  float* labels,
232  RF_params rf_params,
233  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info,
234  bool* bootstrap_masks = nullptr);
235 void fit(const raft::handle_t& user_handle,
236  RandomForestRegressorD* forest,
237  double* input,
238  int n_rows,
239  int n_cols,
240  double* labels,
241  RF_params rf_params,
242  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info,
243  bool* bootstrap_masks = nullptr);
244 
/**
 * Regression variant of fit_treelite(). Same contract as the classification
 * overload, minus n_unique_labels.
 *
 * @param model               [out] presumably receives the trained Treelite
 *                            model -- confirm ownership/lifetime.
 * @param feature_importances [out] presumably one importance per feature
 *                            (n_cols entries) -- confirm.
 * No defaults: bootstrap_masks and verbosity must be passed explicitly.
 */
245 template <typename T, typename L>
246 void fit_treelite(const raft::handle_t& user_handle,
247  TreeliteModelHandle* model,
248  T* input,
249  int n_rows,
250  int n_cols,
251  L* labels,
252  RF_params rf_params,
253  bool* bootstrap_masks,
254  T* feature_importances,
255  rapids_logger::level_enum verbosity);
256 
/**
 * Predict regression targets with a trained regressor (float/double).
 *
 * @param forest      trained model (read-only).
 * @param input       feature data with n_rows x n_cols entries (read-only).
 * @param n_rows      number of samples to predict.
 * @param n_cols      number of features.
 * @param predictions [out] predicted values, presumably n_rows entries.
 * @param verbosity   logging level (default: info).
 */
257 void predict(const raft::handle_t& user_handle,
258  const RandomForestRegressorF* forest,
259  const float* input,
260  int n_rows,
261  int n_cols,
262  float* predictions,
263  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
264 void predict(const raft::handle_t& user_handle,
265  const RandomForestRegressorD* forest,
266  const double* input,
267  int n_rows,
268  int n_cols,
269  double* predictions,
270  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
271 
/**
 * Score regression predictions against reference targets.
 *
 * @param forest      model the predictions came from (read-only).
 * @param ref_labels  ground-truth targets, n_rows entries (read-only).
 * @param n_rows      number of samples compared.
 * @param predictions predicted values, n_rows entries (read-only).
 * @param verbosity   logging level (default: info).
 * @return RF_metrics. Presumably the regression fields (mean_abs_error,
 *         mean_squared_error, median_abs_error) are the meaningful ones --
 *         confirm.
 */
272 RF_metrics score(const raft::handle_t& user_handle,
273  const RandomForestRegressorF* forest,
274  const float* ref_labels,
275  int n_rows,
276  const float* predictions,
277  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
278 RF_metrics score(const raft::handle_t& user_handle,
279  const RandomForestRegressorD* forest,
280  const double* ref_labels,
281  int n_rows,
282  const double* predictions,
283  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
284 }  // namespace ML
Definition: dbscan.hpp:18
void preprocess_labels(int n_rows, std::vector< int > &labels, std::map< int, int > &labels_map, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
std::string get_rf_json(const RandomForestMetaData< T, L > *forest)
std::string get_rf_summary_text(const RandomForestMetaData< T, L > *forest)
RandomForestMetaData< float, int > RandomForestClassifierF
Definition: randomforest.hpp:141
void print(const RF_metrics rf_metrics)
void delete_rf_metadata(RandomForestMetaData< T, L > *forest)
RF_type
Definition: randomforest.hpp:21
@ REGRESSION
Definition: randomforest.hpp:23
@ CLASSIFICATION
Definition: randomforest.hpp:22
void compute_feature_importances(const RandomForestMetaData< T, L > *forest, T *importances)
Compute the feature importances of the trained RandomForest model.
RF_metrics set_all_rf_metrics(RF_type rf_type, float accuracy, double mean_abs_error, double mean_squared_error, double median_abs_error)
RandomForestMetaData< double, double > RandomForestRegressorD
Definition: randomforest.hpp:224
RandomForestMetaData< float, float > RandomForestRegressorF
Definition: randomforest.hpp:223
void postprocess_labels(int n_rows, std::vector< int > &labels, std::map< int, int > &labels_map, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
void build_treelite_forest(TreeliteModelHandle *model, const RandomForestMetaData< T, L > *forest, int num_features)
CRITERION
Definition: algo_helper.h:9
RF_metrics set_rf_metrics_classification(float accuracy)
std::string get_rf_detailed_text(const RandomForestMetaData< T, L > *forest)
RF_metrics score(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const int *ref_labels, int n_rows, const int *predictions, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
RF_params set_rf_params(int max_depth, int max_leaves, float max_features, int max_n_bins, int min_samples_leaf, int min_samples_split, float min_impurity_decrease, bool bootstrap, int n_trees, float max_samples, uint64_t seed, CRITERION split_criterion, int cfg_n_streams, int max_batch_size)
RF_metrics set_rf_metrics_regression(double mean_abs_error, double mean_squared_error, double median_abs_error)
void fit_treelite(const raft::handle_t &user_handle, TreeliteModelHandle *model, T *input, int n_rows, int n_cols, L *labels, int n_unique_labels, RF_params rf_params, bool *bootstrap_masks, T *feature_importances, rapids_logger::level_enum verbosity)
void predict(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const float *input, int n_rows, int n_cols, int *predictions, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
RandomForestMetaData< double, int > RandomForestClassifierD
Definition: randomforest.hpp:142
task_category
Definition: randomforest.hpp:26
@ REGRESSION_MODEL
Definition: randomforest.hpp:26
@ CLASSIFICATION_MODEL
Definition: randomforest.hpp:26
void fit(const raft::handle_t &user_handle, RandomForestClassifierF *forest, float *input, int n_rows, int n_cols, int *labels, int n_unique_labels, RF_params rf_params, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info, bool *bootstrap_masks=nullptr)
Definition: dbscan.hpp:14
Definition: decisiontree.hpp:18
Definition: randomforest.hpp:28
RF_type rf_type
Definition: randomforest.hpp:29
double mean_squared_error
Definition: randomforest.hpp:36
double median_abs_error
Definition: randomforest.hpp:37
float accuracy
Definition: randomforest.hpp:32
double mean_abs_error
Definition: randomforest.hpp:35
Definition: randomforest.hpp:51
uint64_t seed
Definition: randomforest.hpp:77
int n_streams
Definition: randomforest.hpp:83
DT::DecisionTreeParams tree_params
Definition: randomforest.hpp:84
bool bootstrap
Definition: randomforest.hpp:66
int n_trees
Definition: randomforest.hpp:55
float max_samples
Definition: randomforest.hpp:70
Definition: randomforest.hpp:102
int n_features
Definition: randomforest.hpp:109
RF_params rf_params
Definition: randomforest.hpp:104
std::vector< std::shared_ptr< DT::TreeMetaDataNode< T, L > > > trees
Definition: randomforest.hpp:103
void * TreeliteModelHandle
Definition: treelite_defs.hpp:12