53 double mean_abs_error,
54 double mean_squared_error,
55 double median_abs_error);
58 double mean_squared_error,
59 double median_abs_error);
102 std::vector<int>& labels,
103 std::map<int, int>& labels_map,
108 std::vector<int>& labels,
109 std::map<int, int>& labels_map,
112 template <
class T,
class L>
114 std::vector<std::shared_ptr<DT::TreeMetaDataNode<T, L>>>
trees;
118 template <
class T,
class L>
121 template <
class T,
class L>
124 template <
class T,
class L>
127 template <
class T,
class L>
130 template <
class T,
class L>
142 void fit(
const raft::handle_t& user_handle,
151 void fit(
const raft::handle_t& user_handle,
161 void predict(
const raft::handle_t& user_handle,
168 void predict(
const raft::handle_t& user_handle,
178 const int* ref_labels,
180 const int* predictions,
184 const int* ref_labels,
186 const int* predictions,
193 int min_samples_leaf,
194 int min_samples_split,
195 float min_impurity_decrease,
209 void fit(
const raft::handle_t& user_handle,
217 void fit(
const raft::handle_t& user_handle,
226 void predict(
const raft::handle_t& user_handle,
233 void predict(
const raft::handle_t& user_handle,
243 const float* ref_labels,
245 const float* predictions,
249 const double* ref_labels,
251 const double* predictions,
#define CUML_LEVEL_INFO
Definition: log_levels.hpp:28
Definition: dbscan.hpp:30
void postprocess_labels(int n_rows, std::vector< int > &labels, std::map< int, int > &labels_map, int verbosity=CUML_LEVEL_INFO)
void predict(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const float *input, int n_rows, int n_cols, int *predictions, int verbosity=CUML_LEVEL_INFO)
std::string get_rf_json(const RandomForestMetaData< T, L > *forest)
void fit(const raft::handle_t &user_handle, RandomForestClassifierF *&forest, float *input, int n_rows, int n_cols, int *labels, int n_unique_labels, RF_params rf_params, int verbosity=CUML_LEVEL_INFO)
std::string get_rf_summary_text(const RandomForestMetaData< T, L > *forest)
RandomForestMetaData< float, int > RandomForestClassifierF
Definition: randomforest.hpp:139
void print(const RF_metrics rf_metrics)
void delete_rf_metadata(RandomForestMetaData< T, L > *forest)
RF_type
Definition: randomforest.hpp:32
@ REGRESSION
Definition: randomforest.hpp:34
@ CLASSIFICATION
Definition: randomforest.hpp:33
RF_metrics set_all_rf_metrics(RF_type rf_type, float accuracy, double mean_abs_error, double mean_squared_error, double median_abs_error)
RandomForestMetaData< double, double > RandomForestRegressorD
Definition: randomforest.hpp:207
RandomForestMetaData< float, float > RandomForestRegressorF
Definition: randomforest.hpp:206
void build_treelite_forest(TreeliteModelHandle *model, const RandomForestMetaData< T, L > *forest, int num_features)
CRITERION
Definition: algo_helper.h:20
RF_metrics set_rf_metrics_classification(float accuracy)
std::string get_rf_detailed_text(const RandomForestMetaData< T, L > *forest)
void preprocess_labels(int n_rows, std::vector< int > &labels, std::map< int, int > &labels_map, int verbosity=CUML_LEVEL_INFO)
RF_params set_rf_params(int max_depth, int max_leaves, float max_features, int max_n_bins, int min_samples_leaf, int min_samples_split, float min_impurity_decrease, bool bootstrap, int n_trees, float max_samples, uint64_t seed, CRITERION split_criterion, int cfg_n_streams, int max_batch_size)
RF_metrics set_rf_metrics_regression(double mean_abs_error, double mean_squared_error, double median_abs_error)
RandomForestMetaData< double, int > RandomForestClassifierD
Definition: randomforest.hpp:140
RF_metrics score(const raft::handle_t &user_handle, const RandomForestClassifierF *forest, const int *ref_labels, int n_rows, const int *predictions, int verbosity=CUML_LEVEL_INFO)
task_category
Definition: randomforest.hpp:37
@ REGRESSION_MODEL
Definition: randomforest.hpp:37
@ CLASSIFICATION_MODEL
Definition: randomforest.hpp:37
TreeliteModelHandle concatenate_trees(std::vector< TreeliteModelHandle > treelite_handles)
Definition: dbscan.hpp:26
Definition: decisiontree.hpp:29
Definition: randomforest.hpp:39
RF_type rf_type
Definition: randomforest.hpp:40
double mean_squared_error
Definition: randomforest.hpp:47
double median_abs_error
Definition: randomforest.hpp:48
float accuracy
Definition: randomforest.hpp:43
double mean_abs_error
Definition: randomforest.hpp:46
Definition: randomforest.hpp:62
uint64_t seed
Definition: randomforest.hpp:88
int n_streams
Definition: randomforest.hpp:94
DT::DecisionTreeParams tree_params
Definition: randomforest.hpp:95
bool bootstrap
Definition: randomforest.hpp:77
int n_trees
Definition: randomforest.hpp:66
float max_samples
Definition: randomforest.hpp:81
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23