21 #include <treelite/tree.h>
32 template <
typename tl_threshold_t,
typename tl_output_t>
40 bool is_leaf()
const override {
return tl_tree_.IsLeaf(node_id_); }
46 result = tl_tree_.LeftChild(node_id_);
48 result = tl_tree_.RightChild(node_id_);
57 result = tl_tree_.RightChild(node_id_);
59 result = tl_tree_.LeftChild(node_id_);
66 auto get_feature()
const {
return tl_tree_.SplitIndex(node_id_); }
70 auto tl_operator = tl_tree_.ComparisonOp(node_id_);
71 return tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kLE;
76 return tl_tree_.NodeType(node_id_) == treelite::TreeNodeType::kCategoricalTestNode;
81 auto threshold()
const {
return tl_tree_.Threshold(node_id_); }
85 auto result = std::remove_const_t<std::remove_reference_t<decltype(
get_categories()[0])>>{};
88 if (categories.size() != 0) {
89 result = *std::max_element(std::begin(categories), std::end(categories)) + 1;
97 auto result = std::vector<tl_output_t>{};
98 if (tl_tree_.HasLeafVector(node_id_)) {
99 result = tl_tree_.LeafVector(node_id_);
101 result.push_back(tl_tree_.LeafValue(node_id_));
109 treelite::Tree<tl_threshold_t, tl_output_t>
const& tl_tree_;
112 auto left_is_hot()
const
116 if (tl_tree_.CategoryListRightChild(node_id_)) { result =
true; }
118 auto tl_operator = tl_tree_.ComparisonOp(node_id_);
119 if (tl_operator == treelite::Operator::kLT || tl_operator == treelite::Operator::kLE) {
121 }
else if (tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kGE) {
124 throw traversal_exception(
"Unrecognized Treelite operator");
131 template <
typename tl_threshold_t,
typename tl_output_t>
133 :
public traversal_forest<treelite_traversal_node<tl_threshold_t, tl_output_t>> {
145 auto result = std::vector<node_uid_type>{};
146 result.reserve(tl_model.GetNumTree());
147 for (
auto i = std::size_t{}; i < tl_model.GetNumTree(); ++i) {
158 return node_type{tl_model_.trees[tree_id], node_id};
162 treelite::ModelPreset<tl_threshold_t, tl_output_t>
const& tl_model_;
165 template <
typename lambda_t>
169 [&lambda](
auto&& concrete_tl_model) {
170 std::for_each(std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), lambda);
175 template <
typename iter_t,
typename lambda_t>
176 void tree_transform(treelite::Model
const& tl_model, iter_t out_iter, lambda_t&& lambda)
179 [&lambda, out_iter](
auto&& concrete_tl_model) {
181 std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), out_iter, lambda);
186 template <
typename T,
typename lambda_t>
190 [&lambda, init](
auto&& concrete_tl_model) {
191 return std::accumulate(
192 std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), init, lambda);
197 template <forest_order order,
typename lambda_t>
201 [&lambda](
auto&& concrete_tl_model) {
207 template <forest_order order,
typename iter_t,
typename lambda_t>
208 void node_transform(treelite::Model
const& tl_model, iter_t output_iter, lambda_t&& lambda)
210 node_for_each<order>(
212 [&output_iter, &lambda](
auto&& tree_id,
auto&& node,
auto&& depth,
auto&& parent_index) {
213 *output_iter = lambda(tree_id, node, depth, parent_index);
218 template <forest_order order,
typename T,
typename lambda_t>
222 node_for_each<order>(
223 tl_model, [&result, &lambda](
auto&& tree_id,
auto&& node,
auto&& depth,
auto&& parent_index) {
224 result = lambda(result, tree_id, node, depth, parent_index);
void node_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite.hpp:198
void tree_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite.hpp:166
int TREELITE_NODE_ID_T
Definition: treelite.hpp:30
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:187
void node_transform(treelite::Model const &tl_model, iter_t output_iter, lambda_t &&lambda)
Definition: treelite.hpp:208
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:176
auto node_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:219
void transform(const raft::handle_t &handle, const KMeansParams ¶ms, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
Definition: dbscan.hpp:29
Definition: traversal_forest.hpp:75
std::size_t tree_id_type
Definition: traversal_forest.hpp:78
typename node_type::id_type node_id_type
Definition: traversal_forest.hpp:77
std::pair< tree_id_type, node_id_type > node_uid_type
Definition: traversal_forest.hpp:79
treelite_traversal_node< tl_threshold_t, tl_output_t > node_type
Definition: traversal_forest.hpp:76
Definition: traversal_node.hpp:36
TREELITE_NODE_ID_T id_type
Definition: traversal_node.hpp:38
Definition: treelite.hpp:133
treelite_traversal_forest(treelite::ModelPreset< tl_threshold_t, tl_output_t > const &tl_model)
Definition: treelite.hpp:143
node_type get_node(tree_id_type tree_id, node_id_type node_id) const override
Definition: treelite.hpp:156
typename base_type::node_type node_type
Definition: treelite.hpp:138
typename base_type::tree_id_type tree_id_type
Definition: treelite.hpp:140
typename base_type::node_id_type node_id_type
Definition: treelite.hpp:139
typename base_type::node_uid_type node_uid_type
Definition: treelite.hpp:141
Definition: treelite.hpp:33
auto get_output() const
Definition: treelite.hpp:95
auto threshold() const
Definition: treelite.hpp:81
auto is_inclusive() const
Definition: treelite.hpp:68
bool is_leaf() const override
Definition: treelite.hpp:40
auto is_categorical() const
Definition: treelite.hpp:74
id_type hot_child() const override
Definition: treelite.hpp:42
auto get_feature() const
Definition: treelite.hpp:66
treelite_traversal_node(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, id_type node_id)
Definition: treelite.hpp:34
auto get_treelite_id() const
Definition: treelite.hpp:106
auto default_distant() const
Definition: treelite.hpp:64
auto max_num_categories() const
Definition: treelite.hpp:83
auto get_categories() const
Definition: treelite.hpp:79
id_type distant_child() const override
Definition: treelite.hpp:53