53 double mean_abs_error,
54 double mean_squared_error,
55 double median_abs_error);
58 double mean_squared_error,
59 double median_abs_error);
102 std::vector<int>& labels,
103 std::map<int, int>& labels_map,
104 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
108 std::vector<int>& labels,
109 std::map<int, int>& labels_map,
110 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
112 template <
class T,
class L>
114 std::vector<std::shared_ptr<DT::TreeMetaDataNode<T, L>>>
trees;
118 template <
class T,
class L>
121 template <
class T,
class L>
124 template <
class T,
class L>
127 template <
class T,
class L>
130 template <
class T,
class L>
140 void fit(
const raft::handle_t& user_handle,
148 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
149 void fit(
const raft::handle_t& user_handle,
157 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
159 void predict(
const raft::handle_t& user_handle,
165 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
166 void predict(
const raft::handle_t& user_handle,
172 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
176 const int* ref_labels,
178 const int* predictions,
179 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
182 const int* ref_labels,
184 const int* predictions,
185 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
191 int min_samples_leaf,
192 int min_samples_split,
193 float min_impurity_decrease,
207 void fit(
const raft::handle_t& user_handle,
214 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
215 void fit(
const raft::handle_t& user_handle,
222 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
224 void predict(
const raft::handle_t& user_handle,
230 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
231 void predict(
const raft::handle_t& user_handle,
237 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
241 const float* ref_labels,
243 const float* predictions,
244 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
247 const double* ref_labels,
249 const double* predictions,
250 rapids_logger::level_enum verbosity = rapids_logger::level_enum::info);
Definition: dbscan.hpp:29
void preprocess_labels(int n_rows, std::vector< int > &labels, std::map< int, int > &labels_map, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
void fit(const raft::handle_t &user_handle, RandomForestClassifierF *&forest, float *input, int n_rows, int n_cols, int *labels, int n_unique_labels, RF_params rf_params, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
std::string get_rf_json(const RandomForestMetaData< T, L > *forest)
std::string get_rf_summary_text(const RandomForestMetaData< T, L > *forest)
RandomForestMetaData< float, int > RandomForestClassifierF
Definition: randomforest.hpp:137
void print(const RF_metrics rf_metrics)
void delete_rf_metadata(RandomForestMetaData< T, L > *forest)
RF_type
Definition: randomforest.hpp:32
@ REGRESSION
Definition: randomforest.hpp:34
@ CLASSIFICATION
Definition: randomforest.hpp:33
RF_metrics set_all_rf_metrics(RF_type rf_type, float accuracy, double mean_abs_error, double mean_squared_error, double median_abs_error)
RandomForestMetaData< double, double > RandomForestRegressorD
Definition: randomforest.hpp:205
RandomForestMetaData< float, float > RandomForestRegressorF
Definition: randomforest.hpp:204
void postprocess_labels(int n_rows, std::vector< int > &labels, std::map< int, int > &labels_map, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
void build_treelite_forest(TreeliteModelHandle *model, const RandomForestMetaData< T, L > *forest, int num_features)
CRITERION
Definition: algo_helper.h:20
RF_metrics set_rf_metrics_classification(float accuracy)
std::string get_rf_detailed_text(const RandomForestMetaData< T, L > *forest)
RF_metrics score(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const int *ref_labels, int n_rows, const int *predictions, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
RF_params set_rf_params(int max_depth, int max_leaves, float max_features, int max_n_bins, int min_samples_leaf, int min_samples_split, float min_impurity_decrease, bool bootstrap, int n_trees, float max_samples, uint64_t seed, CRITERION split_criterion, int cfg_n_streams, int max_batch_size)
RF_metrics set_rf_metrics_regression(double mean_abs_error, double mean_squared_error, double median_abs_error)
void predict(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const float *input, int n_rows, int n_cols, int *predictions, rapids_logger::level_enum verbosity=rapids_logger::level_enum::info)
RandomForestMetaData< double, int > RandomForestClassifierD
Definition: randomforest.hpp:138
task_category
Definition: randomforest.hpp:37
@ REGRESSION_MODEL
Definition: randomforest.hpp:37
@ CLASSIFICATION_MODEL
Definition: randomforest.hpp:37
Definition: dbscan.hpp:25
Definition: decisiontree.hpp:29
Definition: randomforest.hpp:39
RF_type rf_type
Definition: randomforest.hpp:40
double mean_squared_error
Definition: randomforest.hpp:47
double median_abs_error
Definition: randomforest.hpp:48
float accuracy
Definition: randomforest.hpp:43
double mean_abs_error
Definition: randomforest.hpp:46
Definition: randomforest.hpp:62
uint64_t seed
Definition: randomforest.hpp:88
int n_streams
Definition: randomforest.hpp:94
DT::DecisionTreeParams tree_params
Definition: randomforest.hpp:95
bool bootstrap
Definition: randomforest.hpp:77
int n_trees
Definition: randomforest.hpp:66
float max_samples
Definition: randomforest.hpp:81
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23