randomforest.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <cuml/common/logger.hpp>
22 
23 #include <map>
24 #include <memory>
25 
// Forward declaration of raft::handle_t so the declarations below can name it
// (they all take it as `const raft::handle_t&`).
26 namespace raft {
27 class handle_t; // forward decl
28 }
29 
30 namespace ML {
31 
// Tag for the kind of random forest; used as the type of RF_metrics::rf_type
// and as the first parameter of set_all_rf_metrics (per this page's index).
// NOTE(review): the enumerator lines (33-34) are absent from this listing.
// The definition index at the end of the page places CLASSIFICATION at
// randomforest.hpp:33 and REGRESSION at randomforest.hpp:34.
32 enum RF_type {
35 };
36 
38 
// Bundle of model-quality metrics; this is the return type of every score()
// overload and of the set_*_rf_metrics helpers declared in this header.
// NOTE(review): several member lines are absent from this listing. The
// definition index at the end of the page lists them as:
//   RF_type rf_type            (randomforest.hpp:40)
//   double  mean_abs_error     (randomforest.hpp:46)
//   double  mean_squared_error (randomforest.hpp:47)
//   double  median_abs_error   (randomforest.hpp:48)
39 struct RF_metrics {
41 
42  // Classification metrics
43  float accuracy;
44 
45  // Regression metrics
49 };
50 
// NOTE(review): these are the trailing parameters of set_all_rf_metrics; the
// opening line of the declaration (51) is missing from this listing. The
// definition index at the end of the page gives the full signature as
//   RF_metrics set_all_rf_metrics(RF_type rf_type, float accuracy,
//                                 double mean_abs_error,
//                                 double mean_squared_error,
//                                 double median_abs_error)
52  float accuracy,
53  double mean_abs_error,
54  double mean_squared_error,
55  double median_abs_error);
// Declares a helper that takes the three regression error values (mean
// absolute, mean squared, median absolute) and returns an RF_metrics.
// Only the declaration is visible here; what the returned struct holds in
// its classification field is decided by the definition — TODO confirm.
57 RF_metrics set_rf_metrics_regression(double mean_abs_error,
58  double mean_squared_error,
59  double median_abs_error);
// Declares print() for an RF_metrics value (taken by const value). The output
// destination and format are determined by the definition, not visible here.
60 void print(const RF_metrics rf_metrics);
61 
// Random forest configuration; passed by value to every fit() overload and
// returned by set_rf_params() (both declared in this header).
// NOTE(review): the gaps in the listing's line numbers suggest per-field
// documentation was stripped, and the member at line 95 is missing. The
// definition index at the end of the page lists it as
//   DT::DecisionTreeParams tree_params   (randomforest.hpp:95)
// Field meanings below are not shown in this listing; presumably they match
// the same-named set_rf_params() arguments — confirm against the original.
62 struct RF_params {
66  int n_trees;
77  bool bootstrap;
81  float max_samples;
88  uint64_t seed;
94  int n_streams;
96 };
97 
// Both `labels` and `labels_map` are taken by non-const reference, so the
// definition is free to modify them in place; `verbosity` defaults to
// rapids_logger::level_enum::info.
98 /* Update labels so they are unique from 0 to n_unique_vals.
99  Create an old_label to new_label map per random forest.
100 */
101 void preprocess_labels(int n_rows,
102  std::vector<int>& labels,
103  std::map<int, int>& labels_map,
104  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
105 
// Same parameter list as preprocess_labels: `labels` and `labels_map` are
// non-const references (the definition may modify them); `verbosity`
// defaults to rapids_logger::level_enum::info.
106 /* Revert preprocessing effect, if needed. */
107 void postprocess_labels(int n_rows,
108  std::vector<int>& labels,
109  std::map<int, int>& labels_map,
110  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
111 
// NOTE(review): the line that opens this class template (113) is missing from
// this listing, as is the member at line 115. Judging by the aliases in the
// page's index (e.g. RandomForestMetaData<float, int> RandomForestClassifierF)
// this is presumably `struct RandomForestMetaData`; the index lists the
// missing member as `RF_params rf_params` (randomforest.hpp:115).
// `trees` is a vector of shared_ptr to DT::TreeMetaDataNode<T, L>.
112 template <class T, class L>
114  std::vector<std::shared_ptr<DT::TreeMetaDataNode<T, L>>> trees;
116 };
117 
// NOTE(review): three template heads whose declarations (lines 119, 122, 125)
// are missing from this listing. The page's index lists three function
// templates over RandomForestMetaData<T, L> that appear nowhere else in the
// listing, so these heads presumably belong to them (order not shown):
//   void        delete_rf_metadata(RandomForestMetaData<T, L>* forest)
//   std::string get_rf_summary_text(const RandomForestMetaData<T, L>* forest)
//   std::string get_rf_detailed_text(const RandomForestMetaData<T, L>* forest)
118 template <class T, class L>
120 
121 template <class T, class L>
123 
124 template <class T, class L>
126 
// Declares a function template that takes a read-only pointer to a forest and
// returns a std::string. The name suggests a JSON dump of the forest; the
// exact format is set by the definition, not visible here — TODO confirm.
127 template <class T, class L>
128 std::string get_rf_json(const RandomForestMetaData<T, L>* forest);
129 
// NOTE(review): the line carrying the return type, function name and first
// parameter (131) is missing from this listing. The definition index at the
// end of the page gives the full signature as
//   void build_treelite_forest(TreeliteModelHandle* model,
//                              const RandomForestMetaData<T, L>* forest,
//                              int num_features)
// where TreeliteModelHandle is `void*` (treelite_defs.hpp:23 per the index).
130 template <class T, class L>
132  const RandomForestMetaData<T, L>* forest,
133  int num_features);
134 
135 // ----------------------------- Classification ----------------------------------- //
136 
139 
// Classification fit() overloads: one for float input (RandomForestClassifierF)
// and one for double input (RandomForestClassifierD). Labels are int in both.
//   user_handle     - raft handle, by const reference.
//   forest          - reference to a forest pointer (`*&`), so the definition
//                     can both read and reseat the caller's pointer; who
//                     allocates/owns the forest is not visible here.
//   input           - non-const feature buffer sized by n_rows / n_cols;
//                     layout and memory space are not shown — TODO confirm.
//   labels          - non-const int label buffer.
//   n_unique_labels - number of distinct label values.
//   rf_params       - configuration, passed by value.
//   verbosity       - defaults to rapids_logger::level_enum::info.
// NOTE(review): the alias lines (137-138) are missing from this listing; the
// page's index gives RandomForestClassifierF = RandomForestMetaData<float, int>
// and RandomForestClassifierD = RandomForestMetaData<double, int>.
140 void fit(const raft::handle_t& user_handle,
141  RandomForestClassifierF*& forest,
142  float* input,
143  int n_rows,
144  int n_cols,
145  int* labels,
146  int n_unique_labels,
147  RF_params rf_params,
148  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
149 void fit(const raft::handle_t& user_handle,
150  RandomForestClassifierD*& forest,
151  double* input,
152  int n_rows,
153  int n_cols,
154  int* labels,
155  int n_unique_labels,
156  RF_params rf_params,
157  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
158 
// Classification predict() overloads for float and double input. Unlike
// fit(), the forest and the input buffer are both pointers-to-const here.
//   forest      - read-only pointer to the forest.
//   input       - read-only feature buffer sized by n_rows / n_cols.
//   predictions - non-const int buffer, i.e. the output of the call;
//                 presumably one entry per row — TODO confirm.
//   verbosity   - defaults to rapids_logger::level_enum::info.
159 void predict(const raft::handle_t& user_handle,
160  const RandomForestClassifierF* forest,
161  const float* input,
162  int n_rows,
163  int n_cols,
164  int* predictions,
165  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
166 void predict(const raft::handle_t& user_handle,
167  const RandomForestClassifierD* forest,
168  const double* input,
169  int n_rows,
170  int n_cols,
171  int* predictions,
172  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
173 
// Classification score() overloads. Both take read-only int reference labels
// and read-only int predictions plus a row count, and return an RF_metrics by
// value. The two overloads differ only in the forest type (F vs D).
// Which RF_metrics fields are populated is decided by the definition.
174 RF_metrics score(const raft::handle_t& user_handle,
175  const RandomForestClassifierF* forest,
176  const int* ref_labels,
177  int n_rows,
178  const int* predictions,
179  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
180 RF_metrics score(const raft::handle_t& user_handle,
181  const RandomForestClassifierD* forest,
182  const int* ref_labels,
183  int n_rows,
184  const int* predictions,
185  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
186 
// Declares a factory that builds an RF_params from individual settings and
// returns it by value. It has no defaulted arguments: all 14 must be given.
// `bootstrap`, `n_trees`, `max_samples` and `seed` share names with RF_params
// members; `cfg_n_streams` presumably feeds RF_params::n_streams, and the
// remaining tree-level arguments presumably feed RF_params::tree_params —
// confirm against the definition. CRITERION is declared in algo_helper.h
// (per this page's index).
187 RF_params set_rf_params(int max_depth,
188  int max_leaves,
189  float max_features,
190  int max_n_bins,
191  int min_samples_leaf,
192  int min_samples_split,
193  float min_impurity_decrease,
194  bool bootstrap,
195  int n_trees,
196  float max_samples,
197  uint64_t seed,
198  CRITERION split_criterion,
199  int cfg_n_streams,
200  int max_batch_size);
201 
202 // ----------------------------- Regression ----------------------------------- //
203 
206 
// Regression fit() overloads: float input with float labels
// (RandomForestRegressorF) and double input with double labels
// (RandomForestRegressorD). There is no n_unique_labels parameter, unlike
// the classification overloads.
//   forest    - reference to a forest pointer (`*&`), so the definition can
//               both read and reseat the caller's pointer; ownership is not
//               visible here.
//   input     - non-const feature buffer sized by n_rows / n_cols; layout and
//               memory space are not shown — TODO confirm.
//   labels    - non-const target buffer of the same scalar type as input.
//   rf_params - configuration, passed by value.
//   verbosity - defaults to rapids_logger::level_enum::info.
// NOTE(review): the alias lines (204-205) are missing from this listing; the
// page's index gives RandomForestRegressorF = RandomForestMetaData<float, float>
// and RandomForestRegressorD = RandomForestMetaData<double, double>.
207 void fit(const raft::handle_t& user_handle,
208  RandomForestRegressorF*& forest,
209  float* input,
210  int n_rows,
211  int n_cols,
212  float* labels,
213  RF_params rf_params,
214  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
215 void fit(const raft::handle_t& user_handle,
216  RandomForestRegressorD*& forest,
217  double* input,
218  int n_rows,
219  int n_cols,
220  double* labels,
221  RF_params rf_params,
222  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
223 
// Regression predict() overloads for float and double. The forest and input
// are pointers-to-const; `predictions` is a non-const buffer of the same
// scalar type as the input, i.e. the output of the call (presumably one
// entry per row — TODO confirm). `verbosity` defaults to info.
224 void predict(const raft::handle_t& user_handle,
225  const RandomForestRegressorF* forest,
226  const float* input,
227  int n_rows,
228  int n_cols,
229  float* predictions,
230  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
231 void predict(const raft::handle_t& user_handle,
232  const RandomForestRegressorD* forest,
233  const double* input,
234  int n_rows,
235  int n_cols,
236  double* predictions,
237  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
238 
// Regression score() overloads. Both take read-only reference labels and
// read-only predictions (float for the F forest, double for the D forest)
// plus a row count, and return an RF_metrics by value. Which RF_metrics
// fields are populated is decided by the definition, not visible here.
239 RF_metrics score(const raft::handle_t& user_handle,
240  const RandomForestRegressorF* forest,
241  const float* ref_labels,
242  int n_rows,
243  const float* predictions,
244  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
245 RF_metrics score(const raft::handle_t& user_handle,
246  const RandomForestRegressorD* forest,
247  const double* ref_labels,
248  int n_rows,
249  const double* predictions,
250  rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
251 }; // namespace ML
Definition: dbscan.hpp:29
void preprocess_labels(int n_rows, std::vector< int > &labels, std::map< int, int > &labels_map, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
void fit(const raft::handle_t &user_handle, RandomForestClassifierF *&forest, float *input, int n_rows, int n_cols, int *labels, int n_unique_labels, RF_params rf_params, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
std::string get_rf_json(const RandomForestMetaData< T, L > *forest)
std::string get_rf_summary_text(const RandomForestMetaData< T, L > *forest)
RandomForestMetaData< float, int > RandomForestClassifierF
Definition: randomforest.hpp:137
void print(const RF_metrics rf_metrics)
void delete_rf_metadata(RandomForestMetaData< T, L > *forest)
RF_type
Definition: randomforest.hpp:32
@ REGRESSION
Definition: randomforest.hpp:34
@ CLASSIFICATION
Definition: randomforest.hpp:33
RF_metrics set_all_rf_metrics(RF_type rf_type, float accuracy, double mean_abs_error, double mean_squared_error, double median_abs_error)
RandomForestMetaData< double, double > RandomForestRegressorD
Definition: randomforest.hpp:205
RandomForestMetaData< float, float > RandomForestRegressorF
Definition: randomforest.hpp:204
void postprocess_labels(int n_rows, std::vector< int > &labels, std::map< int, int > &labels_map, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
void build_treelite_forest(TreeliteModelHandle *model, const RandomForestMetaData< T, L > *forest, int num_features)
CRITERION
Definition: algo_helper.h:20
RF_metrics set_rf_metrics_classification(float accuracy)
std::string get_rf_detailed_text(const RandomForestMetaData< T, L > *forest)
RF_metrics score(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const int *ref_labels, int n_rows, const int *predictions, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
RF_params set_rf_params(int max_depth, int max_leaves, float max_features, int max_n_bins, int min_samples_leaf, int min_samples_split, float min_impurity_decrease, bool bootstrap, int n_trees, float max_samples, uint64_t seed, CRITERION split_criterion, int cfg_n_streams, int max_batch_size)
RF_metrics set_rf_metrics_regression(double mean_abs_error, double mean_squared_error, double median_abs_error)
void predict(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const float *input, int n_rows, int n_cols, int *predictions, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
RandomForestMetaData< double, int > RandomForestClassifierD
Definition: randomforest.hpp:138
task_category
Definition: randomforest.hpp:37
@ REGRESSION_MODEL
Definition: randomforest.hpp:37
@ CLASSIFICATION_MODEL
Definition: randomforest.hpp:37
Definition: dbscan.hpp:25
Definition: decisiontree.hpp:29
Definition: randomforest.hpp:39
RF_type rf_type
Definition: randomforest.hpp:40
double mean_squared_error
Definition: randomforest.hpp:47
double median_abs_error
Definition: randomforest.hpp:48
float accuracy
Definition: randomforest.hpp:43
double mean_abs_error
Definition: randomforest.hpp:46
Definition: randomforest.hpp:62
uint64_t seed
Definition: randomforest.hpp:88
int n_streams
Definition: randomforest.hpp:94
DT::DecisionTreeParams tree_params
Definition: randomforest.hpp:95
bool bootstrap
Definition: randomforest.hpp:77
int n_trees
Definition: randomforest.hpp:66
float max_samples
Definition: randomforest.hpp:81
Definition: randomforest.hpp:113
RF_params rf_params
Definition: randomforest.hpp:115
std::vector< std::shared_ptr< DT::TreeMetaDataNode< T, L > > > trees
Definition: randomforest.hpp:114
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23