21 #include <treelite/tree.h>
28 namespace experimental {
33 template <
typename tl_threshold_t,
typename tl_output_t>
41 bool is_leaf()
const override {
return tl_tree_.IsLeaf(node_id_); }
47 result = tl_tree_.LeftChild(node_id_);
49 result = tl_tree_.RightChild(node_id_);
58 result = tl_tree_.RightChild(node_id_);
60 result = tl_tree_.LeftChild(node_id_);
67 auto get_feature()
const {
return tl_tree_.SplitIndex(node_id_); }
71 auto tl_operator = tl_tree_.ComparisonOp(node_id_);
72 return tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kLE;
77 return tl_tree_.NodeType(node_id_) == treelite::TreeNodeType::kCategoricalTestNode;
82 auto threshold()
const {
return tl_tree_.Threshold(node_id_); }
86 auto result = std::remove_const_t<std::remove_reference_t<decltype(
get_categories()[0])>>{};
89 if (categories.size() != 0) {
90 result = *std::max_element(std::begin(categories), std::end(categories)) + 1;
98 auto result = std::vector<tl_output_t>{};
99 if (tl_tree_.HasLeafVector(node_id_)) {
100 result = tl_tree_.LeafVector(node_id_);
102 result.push_back(tl_tree_.LeafValue(node_id_));
110 treelite::Tree<tl_threshold_t, tl_output_t>
const& tl_tree_;
113 auto left_is_hot()
const
117 if (tl_tree_.CategoryListRightChild(node_id_)) { result =
true; }
119 auto tl_operator = tl_tree_.ComparisonOp(node_id_);
120 if (tl_operator == treelite::Operator::kLT || tl_operator == treelite::Operator::kLE) {
122 }
else if (tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kGE) {
125 throw traversal_exception(
"Unrecognized Treelite operator");
132 template <
typename tl_threshold_t,
typename tl_output_t>
134 :
public traversal_forest<treelite_traversal_node<tl_threshold_t, tl_output_t>> {
146 auto result = std::vector<node_uid_type>{};
147 result.reserve(tl_model.GetNumTree());
148 for (
auto i = std::size_t{}; i < tl_model.GetNumTree(); ++i) {
159 return node_type{tl_model_.trees[tree_id], node_id};
163 treelite::ModelPreset<tl_threshold_t, tl_output_t>
const& tl_model_;
166 template <
typename lambda_t>
170 [&lambda](
auto&& concrete_tl_model) {
171 std::for_each(std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), lambda);
176 template <
typename iter_t,
typename lambda_t>
177 void tree_transform(treelite::Model
const& tl_model, iter_t out_iter, lambda_t&& lambda)
180 [&lambda, out_iter](
auto&& concrete_tl_model) {
182 std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), out_iter, lambda);
187 template <
typename T,
typename lambda_t>
191 [&lambda, init](
auto&& concrete_tl_model) {
192 return std::accumulate(
193 std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), init, lambda);
198 template <forest_order order,
typename lambda_t>
202 [&lambda](
auto&& concrete_tl_model) {
208 template <forest_order order,
typename iter_t,
typename lambda_t>
209 void node_transform(treelite::Model
const& tl_model, iter_t output_iter, lambda_t&& lambda)
211 node_for_each<order>(
213 [&output_iter, &lambda](
auto&& tree_id,
auto&& node,
auto&& depth,
auto&& parent_index) {
214 *output_iter = lambda(tree_id, node, depth, parent_index);
219 template <forest_order order,
typename T,
typename lambda_t>
223 node_for_each<order>(
224 tl_model, [&result, &lambda](
auto&& tree_id,
auto&& node,
auto&& depth,
auto&& parent_index) {
225 result = lambda(result, tree_id, node, depth, parent_index);
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:177
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:188
int TREELITE_NODE_ID_T
Definition: treelite.hpp:31
void node_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite.hpp:199
void tree_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite.hpp:167
auto node_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:220
void node_transform(treelite::Model const &tl_model, iter_t output_iter, lambda_t &&lambda)
Definition: treelite.hpp:209
void transform(const raft::handle_t &handle, const KMeansParams ¶ms, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
Definition: dbscan.hpp:30
Definition: traversal_forest.hpp:76
treelite_traversal_node< tl_threshold_t, tl_output_t > node_type
Definition: traversal_forest.hpp:77
typename node_type::id_type node_id_type
Definition: traversal_forest.hpp:78
std::pair< tree_id_type, node_id_type > node_uid_type
Definition: traversal_forest.hpp:80
std::size_t tree_id_type
Definition: traversal_forest.hpp:79
Definition: traversal_node.hpp:37
TREELITE_NODE_ID_T id_type
Definition: traversal_node.hpp:39
Definition: treelite.hpp:134
typename base_type::node_id_type node_id_type
Definition: treelite.hpp:140
typename base_type::tree_id_type tree_id_type
Definition: treelite.hpp:141
typename base_type::node_type node_type
Definition: treelite.hpp:139
node_type get_node(tree_id_type tree_id, node_id_type node_id) const override
Definition: treelite.hpp:157
treelite_traversal_forest(treelite::ModelPreset< tl_threshold_t, tl_output_t > const &tl_model)
Definition: treelite.hpp:144
typename base_type::node_uid_type node_uid_type
Definition: treelite.hpp:142
Definition: treelite.hpp:34
auto default_distant() const
Definition: treelite.hpp:65
auto is_categorical() const
Definition: treelite.hpp:75
auto get_feature() const
Definition: treelite.hpp:67
auto threshold() const
Definition: treelite.hpp:82
id_type distant_child() const override
Definition: treelite.hpp:54
id_type hot_child() const override
Definition: treelite.hpp:43
auto max_num_categories() const
Definition: treelite.hpp:84
auto get_treelite_id() const
Definition: treelite.hpp:107
treelite_traversal_node(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, id_type node_id)
Definition: treelite.hpp:35
auto get_categories() const
Definition: treelite.hpp:80
bool is_leaf() const override
Definition: treelite.hpp:41
auto get_output() const
Definition: treelite.hpp:96
auto is_inclusive() const
Definition: treelite.hpp:69