Loading [MathJax]/extensions/tex2jax.js
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
treelite.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2024-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
20 
21 #include <treelite/tree.h>
22 
23 #include <algorithm>
24 #include <numeric>
25 #include <vector>
26 
27 namespace ML {
28 namespace experimental {
29 namespace forest {
30 
// Integer type used in this file to identify a node within a Treelite tree.
// It is the id-type template argument of the traversal_node base class below,
// and a value-initialized instance of it is used as each tree's starting node
// id when the forest adapter is constructed.
31 using TREELITE_NODE_ID_T = int;
32 
33 template <typename tl_threshold_t, typename tl_output_t>
34 struct treelite_traversal_node : public traversal_node<TREELITE_NODE_ID_T> {
35  treelite_traversal_node(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree,
36  id_type node_id)
37  : traversal_node{}, tl_tree_{tl_tree}, node_id_{node_id}
38  {
39  }
40 
41  bool is_leaf() const override { return tl_tree_.IsLeaf(node_id_); }
42 
43  id_type hot_child() const override
44  {
45  auto result = id_type{};
46  if (left_is_hot()) {
47  result = tl_tree_.LeftChild(node_id_);
48  } else {
49  result = tl_tree_.RightChild(node_id_);
50  }
51  return result;
52  }
53 
54  id_type distant_child() const override
55  {
56  auto result = id_type{};
57  if (left_is_hot()) {
58  result = tl_tree_.RightChild(node_id_);
59  } else {
60  result = tl_tree_.LeftChild(node_id_);
61  }
62  return result;
63  }
64 
65  auto default_distant() const { return tl_tree_.DefaultChild(node_id_) == distant_child(); }
66 
67  auto get_feature() const { return tl_tree_.SplitIndex(node_id_); }
68 
69  auto is_inclusive() const
70  {
71  auto tl_operator = tl_tree_.ComparisonOp(node_id_);
72  return tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kLE;
73  }
74 
75  auto is_categorical() const
76  {
77  return tl_tree_.NodeType(node_id_) == treelite::TreeNodeType::kCategoricalTestNode;
78  }
79 
80  auto get_categories() const { return tl_tree_.CategoryList(node_id_); }
81 
82  auto threshold() const { return tl_tree_.Threshold(node_id_); }
83 
84  auto max_num_categories() const
85  {
86  auto result = std::remove_const_t<std::remove_reference_t<decltype(get_categories()[0])>>{};
87  if (is_categorical()) {
88  auto categories = get_categories();
89  if (categories.size() != 0) {
90  result = *std::max_element(std::begin(categories), std::end(categories)) + 1;
91  }
92  }
93  return result;
94  }
95 
96  auto get_output() const
97  {
98  auto result = std::vector<tl_output_t>{};
99  if (tl_tree_.HasLeafVector(node_id_)) {
100  result = tl_tree_.LeafVector(node_id_);
101  } else {
102  result.push_back(tl_tree_.LeafValue(node_id_));
103  }
104  return result;
105  }
106 
107  auto get_treelite_id() const { return node_id_; }
108 
109  private:
110  treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree_;
111  id_type node_id_;
112 
113  auto left_is_hot() const
114  {
115  auto result = false;
116  if (is_categorical()) {
117  if (tl_tree_.CategoryListRightChild(node_id_)) { result = true; }
118  } else {
119  auto tl_operator = tl_tree_.ComparisonOp(node_id_);
120  if (tl_operator == treelite::Operator::kLT || tl_operator == treelite::Operator::kLE) {
121  result = false;
122  } else if (tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kGE) {
123  result = true;
124  } else {
125  throw traversal_exception("Unrecognized Treelite operator");
126  }
127  }
128  return result;
129  }
130 };
131 
132 template <typename tl_threshold_t, typename tl_output_t>
134  : public traversal_forest<treelite_traversal_node<tl_threshold_t, tl_output_t>> {
135  private:
137 
138  public:
139  using node_type = typename base_type::node_type;
143 
144  treelite_traversal_forest(treelite::ModelPreset<tl_threshold_t, tl_output_t> const& tl_model)
145  : traversal_forest<treelite_traversal_node<tl_threshold_t, tl_output_t>>{[&tl_model]() {
146  auto result = std::vector<node_uid_type>{};
147  result.reserve(tl_model.GetNumTree());
148  for (auto i = std::size_t{}; i < tl_model.GetNumTree(); ++i) {
149  result.push_back(std::make_pair(i, TREELITE_NODE_ID_T{}));
150  }
151  return result;
152  }()},
153  tl_model_{tl_model}
154  {
155  }
156 
157  node_type get_node(tree_id_type tree_id, node_id_type node_id) const override
158  {
159  return node_type{tl_model_.trees[tree_id], node_id};
160  }
161 
162  private:
163  treelite::ModelPreset<tl_threshold_t, tl_output_t> const& tl_model_;
164 };
165 
166 template <typename lambda_t>
167 void tree_for_each(treelite::Model const& tl_model, lambda_t&& lambda)
168 {
169  std::visit(
170  [&lambda](auto&& concrete_tl_model) {
171  std::for_each(std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), lambda);
172  },
173  tl_model.variant_);
174 }
175 
176 template <typename iter_t, typename lambda_t>
177 void tree_transform(treelite::Model const& tl_model, iter_t out_iter, lambda_t&& lambda)
178 {
179  std::visit(
180  [&lambda, out_iter](auto&& concrete_tl_model) {
182  std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), out_iter, lambda);
183  },
184  tl_model.variant_);
185 }
186 
187 template <typename T, typename lambda_t>
188 auto tree_accumulate(treelite::Model const& tl_model, T init, lambda_t&& lambda)
189 {
190  return std::visit(
191  [&lambda, init](auto&& concrete_tl_model) {
192  return std::accumulate(
193  std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), init, lambda);
194  },
195  tl_model.variant_);
196 }
197 
198 template <forest_order order, typename lambda_t>
199 void node_for_each(treelite::Model const& tl_model, lambda_t&& lambda)
200 {
201  std::visit(
202  [&lambda](auto&& concrete_tl_model) {
203  treelite_traversal_forest{concrete_tl_model}.template for_each<order>(lambda);
204  },
205  tl_model.variant_);
206 }
207 
208 template <forest_order order, typename iter_t, typename lambda_t>
209 void node_transform(treelite::Model const& tl_model, iter_t output_iter, lambda_t&& lambda)
210 {
211  node_for_each<order>(
212  tl_model,
213  [&output_iter, &lambda](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
214  *output_iter = lambda(tree_id, node, depth, parent_index);
215  ++output_iter;
216  });
217 }
218 
219 template <forest_order order, typename T, typename lambda_t>
220 auto node_accumulate(treelite::Model const& tl_model, T init, lambda_t&& lambda)
221 {
222  auto result = init;
223  node_for_each<order>(
224  tl_model, [&result, &lambda](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
225  result = lambda(result, tree_id, node, depth, parent_index);
226  });
227  return result;
228 }
229 
230 } // namespace forest
231 } // namespace experimental
232 } // namespace ML
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:177
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:188
int TREELITE_NODE_ID_T
Definition: treelite.hpp:31
void node_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite.hpp:199
void tree_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite.hpp:167
auto node_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:220
void node_transform(treelite::Model const &tl_model, iter_t output_iter, lambda_t &&lambda)
Definition: treelite.hpp:209
void transform(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
Definition: dbscan.hpp:30
Definition: traversal_forest.hpp:76
treelite_traversal_node< tl_threshold_t, tl_output_t > node_type
Definition: traversal_forest.hpp:77
std::pair< tree_id_type, node_id_type > node_uid_type
Definition: traversal_forest.hpp:80
Definition: traversal_node.hpp:37
TREELITE_NODE_ID_T id_type
Definition: traversal_node.hpp:39
typename base_type::node_id_type node_id_type
Definition: treelite.hpp:140
typename base_type::tree_id_type tree_id_type
Definition: treelite.hpp:141
typename base_type::node_type node_type
Definition: treelite.hpp:139
node_type get_node(tree_id_type tree_id, node_id_type node_id) const override
Definition: treelite.hpp:157
treelite_traversal_forest(treelite::ModelPreset< tl_threshold_t, tl_output_t > const &tl_model)
Definition: treelite.hpp:144
typename base_type::node_uid_type node_uid_type
Definition: treelite.hpp:142
auto default_distant() const
Definition: treelite.hpp:65
auto is_categorical() const
Definition: treelite.hpp:75
auto get_feature() const
Definition: treelite.hpp:67
auto threshold() const
Definition: treelite.hpp:82
id_type distant_child() const override
Definition: treelite.hpp:54
id_type hot_child() const override
Definition: treelite.hpp:43
auto max_num_categories() const
Definition: treelite.hpp:84
auto get_treelite_id() const
Definition: treelite.hpp:107
treelite_traversal_node(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, id_type node_id)
Definition: treelite.hpp:35
auto get_categories() const
Definition: treelite.hpp:80
bool is_leaf() const override
Definition: treelite.hpp:41
auto get_output() const
Definition: treelite.hpp:96
auto is_inclusive() const
Definition: treelite.hpp:69