10 #include <treelite/tree.h>
21 template <
typename tl_threshold_t,
typename tl_output_t>
// Reports whether this node is a leaf by forwarding to the wrapped Treelite
// tree's IsLeaf() for this node's id (node_id_).
// Declared `override`, so it implements a virtual member of the base class;
// the base class declaration is not visible in this listing.
29 bool is_leaf()
const override {
return tl_tree_.IsLeaf(node_id_); }
35 result = tl_tree_.LeftChild(node_id_);
37 result = tl_tree_.RightChild(node_id_);
46 result = tl_tree_.RightChild(node_id_);
48 result = tl_tree_.LeftChild(node_id_);
// Returns the value of the wrapped Treelite tree's SplitIndex() for this
// node's id (node_id_). The return type is deduced from that call.
// NOTE(review): presumably this is the index of the feature the node tests;
// confirm against the Treelite Tree API.
55 auto get_feature()
const {
return tl_tree_.SplitIndex(node_id_); }
59 auto tl_operator = tl_tree_.ComparisonOp(node_id_);
60 return tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kLE;
65 return tl_tree_.NodeType(node_id_) == treelite::TreeNodeType::kCategoricalTestNode;
// Returns the value of the wrapped Treelite tree's Threshold() for this
// node's id (node_id_). The return type is deduced from that call.
// NOTE(review): presumably only meaningful for non-leaf, non-categorical
// nodes; this accessor performs no such check itself, so verify at call sites.
70 auto threshold()
const {
return tl_tree_.Threshold(node_id_); }
76 #pragma GCC diagnostic push
77 #pragma GCC diagnostic ignored "-Wfree-nonheap-object"
80 auto result = std::remove_const_t<std::remove_reference_t<decltype(
get_categories()[0])>>{};
83 if (categories.size() != 0) {
84 result = *std::max_element(std::begin(categories), std::end(categories)) + 1;
89 #pragma GCC diagnostic pop
93 auto result = std::vector<tl_output_t>{};
94 if (tl_tree_.HasLeafVector(node_id_)) {
95 result = tl_tree_.LeafVector(node_id_);
97 result.push_back(tl_tree_.LeafValue(node_id_));
105 treelite::Tree<tl_threshold_t, tl_output_t>
const& tl_tree_;
108 auto left_is_hot()
const
112 if (tl_tree_.CategoryListRightChild(node_id_)) { result =
true; }
114 auto tl_operator = tl_tree_.ComparisonOp(node_id_);
115 if (tl_operator == treelite::Operator::kLT || tl_operator == treelite::Operator::kLE) {
117 }
else if (tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kGE) {
120 throw traversal_exception(
"Unrecognized Treelite operator");
127 template <
typename tl_threshold_t,
typename tl_output_t>
129 :
public traversal_forest<treelite_traversal_node<tl_threshold_t, tl_output_t>> {
141 auto result = std::vector<node_uid_type>{};
142 result.reserve(tl_model.GetNumTree());
143 for (
auto i = std::size_t{}; i < tl_model.GetNumTree(); ++i) {
154 return node_type{tl_model_.trees[tree_id], node_id};
158 treelite::ModelPreset<tl_threshold_t, tl_output_t>
const& tl_model_;
161 template <
typename lambda_t>
165 [&lambda](
auto&& concrete_tl_model) {
166 std::for_each(std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), lambda);
171 template <
typename iter_t,
typename lambda_t>
172 void tree_transform(treelite::Model
const& tl_model, iter_t out_iter, lambda_t&& lambda)
175 [&lambda, out_iter](
auto&& concrete_tl_model) {
177 std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), out_iter, lambda);
182 template <
typename T,
typename lambda_t>
186 [&lambda, init](
auto&& concrete_tl_model) {
187 return std::accumulate(
188 std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), init, lambda);
193 template <forest_order order,
typename lambda_t>
197 [&lambda](
auto&& concrete_tl_model) {
203 template <forest_order order,
typename iter_t,
typename lambda_t>
204 void node_transform(treelite::Model
const& tl_model, iter_t output_iter, lambda_t&& lambda)
206 node_for_each<order>(
208 [&output_iter, &lambda](
auto&& tree_id,
auto&& node,
auto&& depth,
auto&& parent_index) {
209 *output_iter = lambda(tree_id, node, depth, parent_index);
214 template <forest_order order,
typename T,
typename lambda_t>
218 node_for_each<order>(
219 tl_model, [&result, &lambda](
auto&& tree_id,
auto&& node,
auto&& depth,
auto&& parent_index) {
220 result = lambda(result, tree_id, node, depth, parent_index);
void node_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite.hpp:194
void tree_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite.hpp:162
int TREELITE_NODE_ID_T
Definition: treelite.hpp:19
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:183
void node_transform(treelite::Model const &tl_model, iter_t output_iter, lambda_t &&lambda)
Definition: treelite.hpp:204
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:172
auto node_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:215
void transform(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
Definition: dbscan.hpp:18
Definition: traversal_forest.hpp:64
std::size_t tree_id_type
Definition: traversal_forest.hpp:67
typename node_type::id_type node_id_type
Definition: traversal_forest.hpp:66
std::pair< tree_id_type, node_id_type > node_uid_type
Definition: traversal_forest.hpp:68
treelite_traversal_node< tl_threshold_t, tl_output_t > node_type
Definition: traversal_forest.hpp:65
Definition: traversal_node.hpp:25
TREELITE_NODE_ID_T id_type
Definition: traversal_node.hpp:27
Definition: treelite.hpp:129
treelite_traversal_forest(treelite::ModelPreset< tl_threshold_t, tl_output_t > const &tl_model)
Definition: treelite.hpp:139
node_type get_node(tree_id_type tree_id, node_id_type node_id) const override
Definition: treelite.hpp:152
typename base_type::node_type node_type
Definition: treelite.hpp:134
typename base_type::tree_id_type tree_id_type
Definition: treelite.hpp:136
typename base_type::node_id_type node_id_type
Definition: treelite.hpp:135
typename base_type::node_uid_type node_uid_type
Definition: treelite.hpp:137
Definition: treelite.hpp:22
auto get_output() const
Definition: treelite.hpp:91
auto threshold() const
Definition: treelite.hpp:70
auto is_inclusive() const
Definition: treelite.hpp:57
bool is_leaf() const override
Definition: treelite.hpp:29
auto is_categorical() const
Definition: treelite.hpp:63
id_type hot_child() const override
Definition: treelite.hpp:31
auto get_feature() const
Definition: treelite.hpp:55
treelite_traversal_node(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, id_type node_id)
Definition: treelite.hpp:23
auto get_treelite_id() const
Definition: treelite.hpp:102
auto default_distant() const
Definition: treelite.hpp:53
auto max_num_categories() const
Definition: treelite.hpp:78
auto get_categories() const
Definition: treelite.hpp:68
id_type distant_child() const override
Definition: treelite.hpp:42