42 double mean_abs_error,
43 double mean_squared_error,
44 double median_abs_error);
47 double mean_squared_error,
48 double median_abs_error);
91 std::vector<int>& labels,
92 std::map<int, int>& labels_map,
93 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
97 std::vector<int>& labels,
98 std::map<int, int>& labels_map,
99 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
101 template <
class T,
class L>
103 std::vector<std::shared_ptr<DT::TreeMetaDataNode<T, L>>>
trees;
112 template <
class T,
class L>
115 template <
class T,
class L>
118 template <
class T,
class L>
121 template <
class T,
class L>
124 template <
class T,
class L>
136 template <
class T,
class L>
144 void fit(
const raft::handle_t& user_handle,
152 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info,
153 bool* bootstrap_masks =
nullptr);
154 void fit(
const raft::handle_t& user_handle,
162 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info,
163 bool* bootstrap_masks =
nullptr);
165 template <
typename T,
typename L>
174 bool* bootstrap_masks,
175 T* feature_importances,
176 rapids_logger::level_enum verbosity);
178 void predict(
const raft::handle_t& user_handle,
184 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
185 void predict(
const raft::handle_t& user_handle,
191 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
195 const int* ref_labels,
197 const int* predictions,
198 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
201 const int* ref_labels,
203 const int* predictions,
204 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
210 int min_samples_leaf,
211 int min_samples_split,
212 float min_impurity_decrease,
226 void fit(
const raft::handle_t& user_handle,
233 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info,
234 bool* bootstrap_masks =
nullptr);
235 void fit(
const raft::handle_t& user_handle,
242 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info,
243 bool* bootstrap_masks =
nullptr);
245 template <
typename T,
typename L>
253 bool* bootstrap_masks,
254 T* feature_importances,
255 rapids_logger::level_enum verbosity);
257 void predict(
const raft::handle_t& user_handle,
263 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
264 void predict(
const raft::handle_t& user_handle,
270 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
274 const float* ref_labels,
276 const float* predictions,
277 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
280 const double* ref_labels,
282 const double* predictions,
283 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
Definition: dbscan.hpp:18
void preprocess_labels(int n_rows, std::vector< int > &labels, std::map< int, int > &labels_map, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
std::string get_rf_json(const RandomForestMetaData< T, L > *forest)
std::string get_rf_summary_text(const RandomForestMetaData< T, L > *forest)
RandomForestMetaData< float, int > RandomForestClassifierF
Definition: randomforest.hpp:141
void print(const RF_metrics rf_metrics)
void delete_rf_metadata(RandomForestMetaData< T, L > *forest)
RF_type
Definition: randomforest.hpp:21
@ REGRESSION
Definition: randomforest.hpp:23
@ CLASSIFICATION
Definition: randomforest.hpp:22
void compute_feature_importances(const RandomForestMetaData< T, L > *forest, T *importances)
Compute the feature importances of the trained RandomForest model.
RF_metrics set_all_rf_metrics(RF_type rf_type, float accuracy, double mean_abs_error, double mean_squared_error, double median_abs_error)
RandomForestMetaData< double, double > RandomForestRegressorD
Definition: randomforest.hpp:224
RandomForestMetaData< float, float > RandomForestRegressorF
Definition: randomforest.hpp:223
void postprocess_labels(int n_rows, std::vector< int > &labels, std::map< int, int > &labels_map, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
void build_treelite_forest(TreeliteModelHandle *model, const RandomForestMetaData< T, L > *forest, int num_features)
CRITERION
Definition: algo_helper.h:9
RF_metrics set_rf_metrics_classification(float accuracy)
std::string get_rf_detailed_text(const RandomForestMetaData< T, L > *forest)
RF_metrics score(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const int *ref_labels, int n_rows, const int *predictions, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
RF_params set_rf_params(int max_depth, int max_leaves, float max_features, int max_n_bins, int min_samples_leaf, int min_samples_split, float min_impurity_decrease, bool bootstrap, int n_trees, float max_samples, uint64_t seed, CRITERION split_criterion, int cfg_n_streams, int max_batch_size)
RF_metrics set_rf_metrics_regression(double mean_abs_error, double mean_squared_error, double median_abs_error)
void fit_treelite(const raft::handle_t &user_handle, TreeliteModelHandle *model, T *input, int n_rows, int n_cols, L *labels, int n_unique_labels, RF_params rf_params, bool *bootstrap_masks, T *feature_importances, rapids_logger::level_enum verbosity)
void predict(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const float *input, int n_rows, int n_cols, int *predictions, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
RandomForestMetaData< double, int > RandomForestClassifierD
Definition: randomforest.hpp:142
task_category
Definition: randomforest.hpp:26
@ REGRESSION_MODEL
Definition: randomforest.hpp:26
@ CLASSIFICATION_MODEL
Definition: randomforest.hpp:26
void fit(const raft::handle_t &user_handle, RandomForestClassifierF *forest, float *input, int n_rows, int n_cols, int *labels, int n_unique_labels, RF_params rf_params, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info, bool *bootstrap_masks=nullptr)
Definition: dbscan.hpp:14
Definition: decisiontree.hpp:18
Definition: randomforest.hpp:28
RF_type rf_type
Definition: randomforest.hpp:29
double mean_squared_error
Definition: randomforest.hpp:36
double median_abs_error
Definition: randomforest.hpp:37
float accuracy
Definition: randomforest.hpp:32
double mean_abs_error
Definition: randomforest.hpp:35
Definition: randomforest.hpp:51
uint64_t seed
Definition: randomforest.hpp:77
int n_streams
Definition: randomforest.hpp:83
DT::DecisionTreeParams tree_params
Definition: randomforest.hpp:84
bool bootstrap
Definition: randomforest.hpp:66
int n_trees
Definition: randomforest.hpp:55
float max_samples
Definition: randomforest.hpp:70
void * TreeliteModelHandle
Definition: treelite_defs.hpp:12