treelite.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2024-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
20 
#include <treelite/tree.h>

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
26 
27 namespace ML {
28 namespace forest {
29 
/// Index type used to address nodes inside a Treelite tree (the traversal node id type).
using TREELITE_NODE_ID_T = int;
31 
32 template <typename tl_threshold_t, typename tl_output_t>
33 struct treelite_traversal_node : public traversal_node<TREELITE_NODE_ID_T> {
34  treelite_traversal_node(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree,
35  id_type node_id)
36  : traversal_node{}, tl_tree_{tl_tree}, node_id_{node_id}
37  {
38  }
39 
40  bool is_leaf() const override { return tl_tree_.IsLeaf(node_id_); }
41 
42  id_type hot_child() const override
43  {
44  auto result = id_type{};
45  if (left_is_hot()) {
46  result = tl_tree_.LeftChild(node_id_);
47  } else {
48  result = tl_tree_.RightChild(node_id_);
49  }
50  return result;
51  }
52 
53  id_type distant_child() const override
54  {
55  auto result = id_type{};
56  if (left_is_hot()) {
57  result = tl_tree_.RightChild(node_id_);
58  } else {
59  result = tl_tree_.LeftChild(node_id_);
60  }
61  return result;
62  }
63 
64  auto default_distant() const { return tl_tree_.DefaultChild(node_id_) == distant_child(); }
65 
66  auto get_feature() const { return tl_tree_.SplitIndex(node_id_); }
67 
68  auto is_inclusive() const
69  {
70  auto tl_operator = tl_tree_.ComparisonOp(node_id_);
71  return tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kLE;
72  }
73 
74  auto is_categorical() const
75  {
76  return tl_tree_.NodeType(node_id_) == treelite::TreeNodeType::kCategoricalTestNode;
77  }
78 
79  auto get_categories() const { return tl_tree_.CategoryList(node_id_); }
80 
81  auto threshold() const { return tl_tree_.Threshold(node_id_); }
82 
83  auto max_num_categories() const
84  {
85  auto result = std::remove_const_t<std::remove_reference_t<decltype(get_categories()[0])>>{};
86  if (is_categorical()) {
87  auto categories = get_categories();
88  if (categories.size() != 0) {
89  result = *std::max_element(std::begin(categories), std::end(categories)) + 1;
90  }
91  }
92  return result;
93  }
94 
95  auto get_output() const
96  {
97  auto result = std::vector<tl_output_t>{};
98  if (tl_tree_.HasLeafVector(node_id_)) {
99  result = tl_tree_.LeafVector(node_id_);
100  } else {
101  result.push_back(tl_tree_.LeafValue(node_id_));
102  }
103  return result;
104  }
105 
106  auto get_treelite_id() const { return node_id_; }
107 
108  private:
109  treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree_;
110  id_type node_id_;
111 
112  auto left_is_hot() const
113  {
114  auto result = false;
115  if (is_categorical()) {
116  if (tl_tree_.CategoryListRightChild(node_id_)) { result = true; }
117  } else {
118  auto tl_operator = tl_tree_.ComparisonOp(node_id_);
119  if (tl_operator == treelite::Operator::kLT || tl_operator == treelite::Operator::kLE) {
120  result = false;
121  } else if (tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kGE) {
122  result = true;
123  } else {
124  throw traversal_exception("Unrecognized Treelite operator");
125  }
126  }
127  return result;
128  }
129 };
130 
131 template <typename tl_threshold_t, typename tl_output_t>
133  : public traversal_forest<treelite_traversal_node<tl_threshold_t, tl_output_t>> {
134  private:
136 
137  public:
138  using node_type = typename base_type::node_type;
142 
143  treelite_traversal_forest(treelite::ModelPreset<tl_threshold_t, tl_output_t> const& tl_model)
144  : traversal_forest<treelite_traversal_node<tl_threshold_t, tl_output_t>>{[&tl_model]() {
145  auto result = std::vector<node_uid_type>{};
146  result.reserve(tl_model.GetNumTree());
147  for (auto i = std::size_t{}; i < tl_model.GetNumTree(); ++i) {
148  result.push_back(std::make_pair(i, TREELITE_NODE_ID_T{}));
149  }
150  return result;
151  }()},
152  tl_model_{tl_model}
153  {
154  }
155 
156  node_type get_node(tree_id_type tree_id, node_id_type node_id) const override
157  {
158  return node_type{tl_model_.trees[tree_id], node_id};
159  }
160 
161  private:
162  treelite::ModelPreset<tl_threshold_t, tl_output_t> const& tl_model_;
163 };
164 
165 template <typename lambda_t>
166 void tree_for_each(treelite::Model const& tl_model, lambda_t&& lambda)
167 {
168  std::visit(
169  [&lambda](auto&& concrete_tl_model) {
170  std::for_each(std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), lambda);
171  },
172  tl_model.variant_);
173 }
174 
175 template <typename iter_t, typename lambda_t>
176 void tree_transform(treelite::Model const& tl_model, iter_t out_iter, lambda_t&& lambda)
177 {
178  std::visit(
179  [&lambda, out_iter](auto&& concrete_tl_model) {
181  std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), out_iter, lambda);
182  },
183  tl_model.variant_);
184 }
185 
186 template <typename T, typename lambda_t>
187 auto tree_accumulate(treelite::Model const& tl_model, T init, lambda_t&& lambda)
188 {
189  return std::visit(
190  [&lambda, init](auto&& concrete_tl_model) {
191  return std::accumulate(
192  std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), init, lambda);
193  },
194  tl_model.variant_);
195 }
196 
197 template <forest_order order, typename lambda_t>
198 void node_for_each(treelite::Model const& tl_model, lambda_t&& lambda)
199 {
200  std::visit(
201  [&lambda](auto&& concrete_tl_model) {
202  treelite_traversal_forest{concrete_tl_model}.template for_each<order>(lambda);
203  },
204  tl_model.variant_);
205 }
206 
207 template <forest_order order, typename iter_t, typename lambda_t>
208 void node_transform(treelite::Model const& tl_model, iter_t output_iter, lambda_t&& lambda)
209 {
210  node_for_each<order>(
211  tl_model,
212  [&output_iter, &lambda](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
213  *output_iter = lambda(tree_id, node, depth, parent_index);
214  ++output_iter;
215  });
216 }
217 
218 template <forest_order order, typename T, typename lambda_t>
219 auto node_accumulate(treelite::Model const& tl_model, T init, lambda_t&& lambda)
220 {
221  auto result = init;
222  node_for_each<order>(
223  tl_model, [&result, &lambda](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
224  result = lambda(result, tree_id, node, depth, parent_index);
225  });
226  return result;
227 }
228 
229 } // namespace forest
230 } // namespace ML
void node_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite.hpp:198
void tree_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite.hpp:166
int TREELITE_NODE_ID_T
Definition: treelite.hpp:30
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:187
void node_transform(treelite::Model const &tl_model, iter_t output_iter, lambda_t &&lambda)
Definition: treelite.hpp:208
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:176
auto node_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:219
void transform(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
Definition: dbscan.hpp:29
Definition: traversal_forest.hpp:75
typename node_type::id_type node_id_type
Definition: traversal_forest.hpp:77
std::pair< tree_id_type, node_id_type > node_uid_type
Definition: traversal_forest.hpp:79
treelite_traversal_node< tl_threshold_t, tl_output_t > node_type
Definition: traversal_forest.hpp:76
Definition: traversal_node.hpp:36
TREELITE_NODE_ID_T id_type
Definition: traversal_node.hpp:38
Definition: treelite.hpp:133
treelite_traversal_forest(treelite::ModelPreset< tl_threshold_t, tl_output_t > const &tl_model)
Definition: treelite.hpp:143
node_type get_node(tree_id_type tree_id, node_id_type node_id) const override
Definition: treelite.hpp:156
typename base_type::node_type node_type
Definition: treelite.hpp:138
typename base_type::tree_id_type tree_id_type
Definition: treelite.hpp:140
typename base_type::node_id_type node_id_type
Definition: treelite.hpp:139
typename base_type::node_uid_type node_uid_type
Definition: treelite.hpp:141
Definition: treelite.hpp:33
auto get_output() const
Definition: treelite.hpp:95
auto threshold() const
Definition: treelite.hpp:81
auto is_inclusive() const
Definition: treelite.hpp:68
bool is_leaf() const override
Definition: treelite.hpp:40
auto is_categorical() const
Definition: treelite.hpp:74
id_type hot_child() const override
Definition: treelite.hpp:42
auto get_feature() const
Definition: treelite.hpp:66
treelite_traversal_node(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, id_type node_id)
Definition: treelite.hpp:34
auto get_treelite_id() const
Definition: treelite.hpp:106
auto default_distant() const
Definition: treelite.hpp:64
auto max_num_categories() const
Definition: treelite.hpp:83
auto get_categories() const
Definition: treelite.hpp:79
id_type distant_child() const override
Definition: treelite.hpp:53