treelite.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
9 
10 #include <treelite/tree.h>
11 
12 #include <algorithm>
13 #include <numeric>
14 #include <vector>
15 
16 namespace ML {
17 namespace forest {
18 
/** Node identifier type used by the Treelite traversal adapters declared in this header. */
using TREELITE_NODE_ID_T = int;
20 
21 template <typename tl_threshold_t, typename tl_output_t>
22 struct treelite_traversal_node : public traversal_node<TREELITE_NODE_ID_T> {
23  treelite_traversal_node(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree,
24  id_type node_id)
25  : traversal_node{}, tl_tree_{tl_tree}, node_id_{node_id}
26  {
27  }
28 
29  bool is_leaf() const override { return tl_tree_.IsLeaf(node_id_); }
30 
31  id_type hot_child() const override
32  {
33  auto result = id_type{};
34  if (left_is_hot()) {
35  result = tl_tree_.LeftChild(node_id_);
36  } else {
37  result = tl_tree_.RightChild(node_id_);
38  }
39  return result;
40  }
41 
42  id_type distant_child() const override
43  {
44  auto result = id_type{};
45  if (left_is_hot()) {
46  result = tl_tree_.RightChild(node_id_);
47  } else {
48  result = tl_tree_.LeftChild(node_id_);
49  }
50  return result;
51  }
52 
53  auto default_distant() const { return tl_tree_.DefaultChild(node_id_) == distant_child(); }
54 
55  auto get_feature() const { return tl_tree_.SplitIndex(node_id_); }
56 
57  auto is_inclusive() const
58  {
59  auto tl_operator = tl_tree_.ComparisonOp(node_id_);
60  return tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kLE;
61  }
62 
63  auto is_categorical() const
64  {
65  return tl_tree_.NodeType(node_id_) == treelite::TreeNodeType::kCategoricalTestNode;
66  }
67 
68  auto get_categories() const { return tl_tree_.CategoryList(node_id_); }
69 
70  auto threshold() const { return tl_tree_.Threshold(node_id_); }
71 
72 // Temporarily disable free-nonheap-object warning to work around spurious warnings emitted by
73 // GCC 14.x. See https://github.com/rapidsai/cuml/pull/7471#issuecomment-3525796585 for more
74 // details.
75 // TODO(hcho3): Remove this pragma once GCC is upgraded to 15.
76 #pragma GCC diagnostic push
77 #pragma GCC diagnostic ignored "-Wfree-nonheap-object"
78  auto max_num_categories() const
79  {
80  auto result = std::remove_const_t<std::remove_reference_t<decltype(get_categories()[0])>>{};
81  if (is_categorical()) {
82  auto categories = get_categories();
83  if (categories.size() != 0) {
84  result = *std::max_element(std::begin(categories), std::end(categories)) + 1;
85  }
86  }
87  return result;
88  }
89 #pragma GCC diagnostic pop
90 
91  auto get_output() const
92  {
93  auto result = std::vector<tl_output_t>{};
94  if (tl_tree_.HasLeafVector(node_id_)) {
95  result = tl_tree_.LeafVector(node_id_);
96  } else {
97  result.push_back(tl_tree_.LeafValue(node_id_));
98  }
99  return result;
100  }
101 
102  auto get_treelite_id() const { return node_id_; }
103 
104  private:
105  treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree_;
106  id_type node_id_;
107 
108  auto left_is_hot() const
109  {
110  auto result = false;
111  if (is_categorical()) {
112  if (tl_tree_.CategoryListRightChild(node_id_)) { result = true; }
113  } else {
114  auto tl_operator = tl_tree_.ComparisonOp(node_id_);
115  if (tl_operator == treelite::Operator::kLT || tl_operator == treelite::Operator::kLE) {
116  result = false;
117  } else if (tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kGE) {
118  result = true;
119  } else {
120  throw traversal_exception("Unrecognized Treelite operator");
121  }
122  }
123  return result;
124  }
125 };
126 
127 template <typename tl_threshold_t, typename tl_output_t>
129  : public traversal_forest<treelite_traversal_node<tl_threshold_t, tl_output_t>> {
130  private:
132 
133  public:
134  using node_type = typename base_type::node_type;
138 
139  treelite_traversal_forest(treelite::ModelPreset<tl_threshold_t, tl_output_t> const& tl_model)
140  : traversal_forest<treelite_traversal_node<tl_threshold_t, tl_output_t>>{[&tl_model]() {
141  auto result = std::vector<node_uid_type>{};
142  result.reserve(tl_model.GetNumTree());
143  for (auto i = std::size_t{}; i < tl_model.GetNumTree(); ++i) {
144  result.push_back(std::make_pair(i, TREELITE_NODE_ID_T{}));
145  }
146  return result;
147  }()},
148  tl_model_{tl_model}
149  {
150  }
151 
152  node_type get_node(tree_id_type tree_id, node_id_type node_id) const override
153  {
154  return node_type{tl_model_.trees[tree_id], node_id};
155  }
156 
157  private:
158  treelite::ModelPreset<tl_threshold_t, tl_output_t> const& tl_model_;
159 };
160 
161 template <typename lambda_t>
162 void tree_for_each(treelite::Model const& tl_model, lambda_t&& lambda)
163 {
164  std::visit(
165  [&lambda](auto&& concrete_tl_model) {
166  std::for_each(std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), lambda);
167  },
168  tl_model.variant_);
169 }
170 
171 template <typename iter_t, typename lambda_t>
172 void tree_transform(treelite::Model const& tl_model, iter_t out_iter, lambda_t&& lambda)
173 {
174  std::visit(
175  [&lambda, out_iter](auto&& concrete_tl_model) {
177  std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), out_iter, lambda);
178  },
179  tl_model.variant_);
180 }
181 
182 template <typename T, typename lambda_t>
183 auto tree_accumulate(treelite::Model const& tl_model, T init, lambda_t&& lambda)
184 {
185  return std::visit(
186  [&lambda, init](auto&& concrete_tl_model) {
187  return std::accumulate(
188  std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), init, lambda);
189  },
190  tl_model.variant_);
191 }
192 
193 template <forest_order order, typename lambda_t>
194 void node_for_each(treelite::Model const& tl_model, lambda_t&& lambda)
195 {
196  std::visit(
197  [&lambda](auto&& concrete_tl_model) {
198  treelite_traversal_forest{concrete_tl_model}.template for_each<order>(lambda);
199  },
200  tl_model.variant_);
201 }
202 
203 template <forest_order order, typename iter_t, typename lambda_t>
204 void node_transform(treelite::Model const& tl_model, iter_t output_iter, lambda_t&& lambda)
205 {
206  node_for_each<order>(
207  tl_model,
208  [&output_iter, &lambda](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
209  *output_iter = lambda(tree_id, node, depth, parent_index);
210  ++output_iter;
211  });
212 }
213 
214 template <forest_order order, typename T, typename lambda_t>
215 auto node_accumulate(treelite::Model const& tl_model, T init, lambda_t&& lambda)
216 {
217  auto result = init;
218  node_for_each<order>(
219  tl_model, [&result, &lambda](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
220  result = lambda(result, tree_id, node, depth, parent_index);
221  });
222  return result;
223 }
224 
225 } // namespace forest
226 } // namespace ML
void node_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite.hpp:194
void tree_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite.hpp:162
int TREELITE_NODE_ID_T
Definition: treelite.hpp:19
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:183
void node_transform(treelite::Model const &tl_model, iter_t output_iter, lambda_t &&lambda)
Definition: treelite.hpp:204
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:172
auto node_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:215
void transform(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
Definition: dbscan.hpp:18
Definition: traversal_forest.hpp:64
typename node_type::id_type node_id_type
Definition: traversal_forest.hpp:66
std::pair< tree_id_type, node_id_type > node_uid_type
Definition: traversal_forest.hpp:68
treelite_traversal_node< tl_threshold_t, tl_output_t > node_type
Definition: traversal_forest.hpp:65
Definition: traversal_node.hpp:25
TREELITE_NODE_ID_T id_type
Definition: traversal_node.hpp:27
Definition: treelite.hpp:129
treelite_traversal_forest(treelite::ModelPreset< tl_threshold_t, tl_output_t > const &tl_model)
Definition: treelite.hpp:139
node_type get_node(tree_id_type tree_id, node_id_type node_id) const override
Definition: treelite.hpp:152
typename base_type::node_type node_type
Definition: treelite.hpp:134
typename base_type::tree_id_type tree_id_type
Definition: treelite.hpp:136
typename base_type::node_id_type node_id_type
Definition: treelite.hpp:135
typename base_type::node_uid_type node_uid_type
Definition: treelite.hpp:137
Definition: treelite.hpp:22
auto get_output() const
Definition: treelite.hpp:91
auto threshold() const
Definition: treelite.hpp:70
auto is_inclusive() const
Definition: treelite.hpp:57
bool is_leaf() const override
Definition: treelite.hpp:29
auto is_categorical() const
Definition: treelite.hpp:63
id_type hot_child() const override
Definition: treelite.hpp:31
auto get_feature() const
Definition: treelite.hpp:55
treelite_traversal_node(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, id_type node_id)
Definition: treelite.hpp:23
auto get_treelite_id() const
Definition: treelite.hpp:102
auto default_distant() const
Definition: treelite.hpp:53
auto max_num_categories() const
Definition: treelite.hpp:78
auto get_categories() const
Definition: treelite.hpp:68
id_type distant_child() const override
Definition: treelite.hpp:42