Loading [MathJax]/extensions/tex2jax.js
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
treelite_importer.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
26 
27 #include <treelite/c_api.h>
28 #include <treelite/enum/task_type.h>
29 #include <treelite/enum/tree_node_type.h>
30 #include <treelite/enum/typeinfo.h>
31 #include <treelite/tree.h>
32 
33 #include <cmath>
34 #include <variant>
35 
36 namespace ML {
37 namespace experimental {
38 namespace fil {
39 
40 namespace detail {
41 
/* Scalar parameter for the postprocessing ops; defaults to 1.0 and is
 * overridden by get_postproc_params for the sigmoid, multiclass_ova and
 * exponential_standard_ratio postprocessors. */
45  double constant = 1.0;
46 };
47 } // namespace detail
48 
54 template <tree_layout layout>
/* Compile-time traversal order used when walking the Treelite model for this
 * `layout`. It is computed by an immediately-invoked constexpr lambda that
 * branches on `layout` with `if constexpr`. Layouts without a branch fail to
 * compile via the static_assert below. */
56  auto static constexpr const traversal_order = []() constexpr {
57  if constexpr (layout == tree_layout::depth_first) {
59  } else if constexpr (layout == tree_layout::breadth_first) {
61  } else if constexpr (layout == tree_layout::layered_children_together) {
63  } else {
// Always false in this branch (depth_first was handled above). Because the
// condition depends on `layout`, it only fires if this branch is instantiated.
64  static_assert(layout == tree_layout::depth_first,
65  "Layout not yet implemented in treelite importer for FIL");
66  }
67  }();
68 
/* Total number of nodes in the model: accumulates tree.num_nodes over all
 * trees, starting from index_type{} (zero). */
69  auto get_node_count(treelite::Model const& tl_model)
70  {
72  tl_model, index_type{}, [](auto&& count, auto&& tree) { return count + tree.num_nodes; });
73  }
74 
75  /* Return vector of offsets between each node and its most distant child */
76  auto get_offsets(treelite::Model const& tl_model)
77  {
78  auto node_count = get_node_count(tl_model);
79  auto result = std::vector<index_type>(node_count);
80  auto parent_indexes = std::vector<index_type>{};
81  parent_indexes.reserve(node_count);
82  ML::experimental::forest::node_transform<traversal_order>(
83  tl_model,
84  std::back_inserter(parent_indexes),
85  [](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) { return parent_index; });
86  for (auto i = std::size_t{}; i < node_count; ++i) {
87  result[parent_indexes[i]] = i - parent_indexes[i];
88  }
89  return result;
90  }
91 
92  auto num_trees(treelite::Model const& tl_model)
93  {
94  auto result = index_type{};
95  std::visit([&result](auto&& concrete_tl_model) { result = concrete_tl_model.trees.size(); },
96  tl_model.variant_);
97  return result;
98  }
99 
/* Collect tree.num_nodes for each tree of the model into a vector, appending
 * through a back_inserter. */
100  auto get_tree_sizes(treelite::Model const& tl_model)
101  {
102  auto result = std::vector<index_type>{};
104  tl_model, std::back_inserter(result), [](auto&& tree) { return tree.num_nodes; });
105  return result;
106  }
107 
108  auto get_num_class(treelite::Model const& tl_model)
109  {
110  return static_cast<index_type>(tl_model.num_class[0]);
111  }
112 
113  auto get_num_feature(treelite::Model const& tl_model)
114  {
115  return static_cast<index_type>(tl_model.num_feature);
116  }
117 
118  auto get_max_num_categories(treelite::Model const& tl_model)
119  {
120  return ML::experimental::forest::node_accumulate<traversal_order>(
121  tl_model,
122  index_type{},
123  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
124  return std::max(cur_accum, static_cast<index_type>(node.max_num_categories()));
125  });
126  }
127 
128  auto get_num_categorical_nodes(treelite::Model const& tl_model)
129  {
130  return ML::experimental::forest::node_accumulate<traversal_order>(
131  tl_model,
132  index_type{},
133  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
134  return cur_accum + static_cast<index_type>(node.is_categorical());
135  });
136  }
137 
138  auto get_num_leaf_vector_nodes(treelite::Model const& tl_model)
139  {
140  return ML::experimental::forest::node_accumulate<traversal_order>(
141  tl_model,
142  index_type{},
143  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
144  auto accum = cur_accum;
145  if (node.is_leaf() && node.get_output().size() > 1) { ++accum; }
146  return accum;
147  });
148  }
149 
150  auto get_average_factor(treelite::Model const& tl_model)
151  {
152  auto result = double{};
153  if (tl_model.average_tree_output) {
154  if (tl_model.task_type == treelite::TaskType::kMultiClf &&
155  tl_model.leaf_vector_shape[1] == 1) { // grove-per-class
156  result = num_trees(tl_model) / tl_model.num_class[0];
157  } else {
158  result = num_trees(tl_model);
159  }
160  } else {
161  result = 1.0;
162  }
163  return result;
164  }
165 
166  auto get_bias(treelite::Model const& tl_model)
167  {
168  return static_cast<double>(tl_model.base_scores[0]);
169  }
170 
171  auto get_postproc_params(treelite::Model const& tl_model)
172  {
173  auto result = detail::postproc_params_t{};
174  auto tl_pred_transform = tl_model.postprocessor;
175  if (tl_pred_transform == std::string{"identity"} ||
176  tl_pred_transform == std::string{"identity_multiclass"}) {
177  result.element = element_op::disable;
178  result.row = row_op::disable;
179  } else if (tl_pred_transform == std::string{"signed_square"}) {
180  result.element = element_op::signed_square;
181  } else if (tl_pred_transform == std::string{"hinge"}) {
182  result.element = element_op::hinge;
183  } else if (tl_pred_transform == std::string{"sigmoid"}) {
184  result.constant = tl_model.sigmoid_alpha;
185  result.element = element_op::sigmoid;
186  } else if (tl_pred_transform == std::string{"exponential"}) {
187  result.element = element_op::exponential;
188  } else if (tl_pred_transform == std::string{"exponential_standard_ratio"}) {
189  result.constant = -tl_model.ratio_c / std::log(2);
190  result.element = element_op::exponential;
191  } else if (tl_pred_transform == std::string{"logarithm_one_plus_exp"}) {
192  result.element = element_op::logarithm_one_plus_exp;
193  } else if (tl_pred_transform == std::string{"max_index"}) {
194  result.row = row_op::max_index;
195  } else if (tl_pred_transform == std::string{"softmax"}) {
196  result.row = row_op::softmax;
197  } else if (tl_pred_transform == std::string{"multiclass_ova"}) {
198  result.constant = tl_model.sigmoid_alpha;
199  result.element = element_op::sigmoid;
200  } else {
201  throw model_import_error{"Unrecognized Treelite pred_transform string"};
202  }
203  return result;
204  }
205 
206  auto uses_double_thresholds(treelite::Model const& tl_model)
207  {
208  auto result = false;
209  switch (tl_model.GetThresholdType()) {
210  case treelite::TypeInfo::kFloat64: result = true; break;
211  case treelite::TypeInfo::kFloat32: result = false; break;
212  default: throw model_import_error("Unrecognized Treelite threshold type");
213  }
214  return result;
215  }
216 
217  auto uses_double_outputs(treelite::Model const& tl_model)
218  {
219  auto result = false;
220  switch (tl_model.GetThresholdType()) {
221  case treelite::TypeInfo::kFloat64: result = true; break;
222  case treelite::TypeInfo::kFloat32: result = false; break;
223  case treelite::TypeInfo::kUInt32: result = false; break;
224  default: throw model_import_error("Unrecognized Treelite threshold type");
225  }
226  return result;
227  }
228 
229  auto uses_integer_outputs(treelite::Model const& tl_model)
230  {
231  auto result = false;
232  switch (tl_model.GetThresholdType()) {
233  case treelite::TypeInfo::kFloat64: result = false; break;
234  case treelite::TypeInfo::kFloat32: result = false; break;
235  case treelite::TypeInfo::kUInt32: result = true; break;
236  default: throw model_import_error("Unrecognized Treelite threshold type");
237  }
238  return result;
239  }
240 
/* Build the forest alternative of decision_forest_variant selected by
 * target_variant_index.
 *
 * Instantiates recursively over variant_index = 0, 1, ... At the index equal
 * to target_variant_index, it walks every Treelite node in traversal_order,
 * feeds the nodes to a decision_forest_builder, and emplaces the resulting
 * forest into that alternative. Once variant_index reaches the variant size
 * without a match, a default-constructed decision_forest_variant is returned. */
245  template <index_type variant_index>
246  auto import_to_specific_variant(index_type target_variant_index,
247  treelite::Model const& tl_model,
248  index_type num_class,
249  index_type num_feature,
250  index_type max_num_categories,
251  std::vector<index_type> const& offsets,
252  index_type align_bytes = index_type{},
254  int device = 0,
256  {
257  auto result = decision_forest_variant{};
// Recursion terminates when variant_index equals the number of alternatives.
258  if constexpr (variant_index != std::variant_size_v<decision_forest_variant>) {
259  if (variant_index == target_variant_index) {
260  using forest_model_t = std::variant_alternative_t<variant_index, decision_forest_variant>;
261  if constexpr (traversal_order ==
263  // Cannot align whole trees with layered traversal order, since trees
264  // are mingled together
265  align_bytes = index_type{};
266  }
267  auto builder =
268  detail::decision_forest_builder<forest_model_t>(max_num_categories, align_bytes);
// node_index is the position of the current node in traversal order and is
// used to index `offsets`. NOTE(review): assumes `offsets` was computed with
// the same traversal_order, as get_offsets does.
269  auto node_index = index_type{};
270  ML::experimental::forest::node_for_each<traversal_order>(
271  tl_model,
272  [&builder, &offsets, &node_index](
273  auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
274  if (node.is_leaf()) {
275  auto output = node.get_output();
276  builder.set_output_size(output.size());
// Multi-value outputs become leaf-vector nodes; a single value is converted
// to the forest's io_type and added as a leaf.
277  if (output.size() > index_type{1}) {
278  builder.add_leaf_vector_node(
279  std::begin(output), std::end(output), node.get_treelite_id(), depth);
280  } else {
281  builder.add_node(
282  typename forest_model_t::io_type(output[0]), node.get_treelite_id(), depth, true);
283  }
284  } else {
// Internal node: categorical splits carry their category set, numerical
// splits carry a threshold converted to the forest's threshold_type.
285  if (node.is_categorical()) {
286  auto categories = node.get_categories();
287  builder.add_categorical_node(std::begin(categories),
288  std::end(categories),
289  node.get_treelite_id(),
290  depth,
291  node.default_distant(),
292  node.get_feature(),
293  offsets[node_index]);
294  } else {
295  builder.add_node(typename forest_model_t::threshold_type(node.threshold()),
296  node.get_treelite_id(),
297  depth,
298  false,
299  node.default_distant(),
300  false,
301  node.get_feature(),
302  offsets[node_index],
303  node.is_inclusive());
304  }
305  }
306  ++node_index;
307  });
308 
// Model-wide settings: averaging factor, bias and postprocessing ops.
309  builder.set_average_factor(get_average_factor(tl_model));
310  builder.set_bias(get_bias(tl_model));
311  auto postproc_params = get_postproc_params(tl_model);
312  builder.set_element_postproc(postproc_params.element);
313  builder.set_row_postproc(postproc_params.row);
314  builder.set_postproc_constant(postproc_params.constant);
315 
316  result.template emplace<variant_index>(
317  builder.get_decision_forest(num_feature, num_class, mem_type, device, stream));
318  } else {
// Not the requested alternative: try the next variant index.
319  result = import_to_specific_variant<variant_index + 1>(target_variant_index,
320  tl_model,
321  num_class,
322  num_feature,
323  max_num_categories,
324  offsets,
325  align_bytes,
326  mem_type,
327  device,
328  stream);
329  }
330  }
331  return result;
332  }
333 
/* Validate a Treelite model and import it into the FIL forest alternative
 * chosen by get_forest_variant_index.
 *
 * Validation, via ASSERT:
 *  - the model has exactly one target;
 *  - for kMultiClf, either vector leaves sized num_class with every class_id
 *    == -1, or grove-per-class with class_id[tree_id] == tree_id % num_class;
 *  - all base_scores are identical.
 * If use_double_precision is set, it overrides the threshold precision
 * reported by the model. */
356  auto import(treelite::Model const& tl_model,
357  index_type align_bytes = index_type{},
358  std::optional<bool> use_double_precision = std::nullopt,
360  int device = 0,
362  {
363  ASSERT(tl_model.num_target == 1, "FIL does not support multi-target model");
364  // Check tree annotation (assignment)
365  if (tl_model.task_type == treelite::TaskType::kMultiClf) {
366  // Must be either vector leaf or grove-per-class
367  if (tl_model.leaf_vector_shape[1] > 1) { // vector-leaf
368  ASSERT(tl_model.leaf_vector_shape[1] == int(tl_model.num_class[0]),
369  "Vector leaf must be equal to num_class = %d",
370  tl_model.num_class[0]);
371  auto tree_count = num_trees(tl_model);
372  for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
373  ASSERT(tl_model.class_id[tree_id] == -1, "Tree %d has invalid class assignment", tree_id);
374  }
375  } else { // grove-per-class
// NOTE(review): this does not check that the tree count is a multiple of
// num_class, which get_average_factor's integer division implicitly assumes.
376  auto tree_count = num_trees(tl_model);
377  for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
378  ASSERT(tl_model.class_id[tree_id] == int(tree_id % tl_model.num_class[0]),
379  "Tree %d has invalid class assignment",
380  tree_id);
381  }
382  }
383  }
384  // Check base_scores
385  for (std::int32_t class_id = 1; class_id < tl_model.num_class[0]; ++class_id) {
386  ASSERT(tl_model.base_scores[0] == tl_model.base_scores[class_id],
387  "base_scores must be identical for all classes");
388  }
389 
// NOTE(review): `result` is never used in this function.
390  auto result = decision_forest_variant{};
391  auto num_feature = get_num_feature(tl_model);
392  auto max_num_categories = get_max_num_categories(tl_model);
393  auto num_categorical_nodes = get_num_categorical_nodes(tl_model);
394  auto num_leaf_vector_nodes = get_num_leaf_vector_nodes(tl_model);
// NOTE(review): value_or evaluates its argument eagerly, so an unrecognized
// threshold type throws even when use_double_precision was provided.
395  auto use_double_thresholds = use_double_precision.value_or(uses_double_thresholds(tl_model));
396 
397  auto offsets = get_offsets(tl_model);
// NOTE(review): dereferences end() (UB) if the model has no nodes and
// `offsets` is empty; consider guarding.
398  auto max_offset = *std::max_element(std::begin(offsets), std::end(offsets));
399 
// Pick the variant alternative that fits precision, offset range, feature
// count, categorical usage, vector leaves and layout.
400  auto variant_index = get_forest_variant_index(use_double_thresholds,
401  max_offset,
402  num_feature,
403  num_categorical_nodes,
404  max_num_categories,
405  num_leaf_vector_nodes,
406  layout);
407  auto num_class = get_num_class(tl_model);
408  return forest_model{import_to_specific_variant<index_type{}>(variant_index,
409  tl_model,
410  num_class,
411  num_feature,
412  max_num_categories,
413  offsets,
414  align_bytes,
415  dev_type,
416  device,
417  stream)};
418  }
419 };
420 
/* Import a Treelite model into a FIL forest_model. Dispatches on the runtime
 * `layout` to the treelite_importer specialized for that layout and forwards
 * all remaining arguments to its import().
 *
 * NOTE(review): the switch has no default case, so a `layout` value outside
 * the three handled enumerators returns a default-constructed forest_model
 * instead of failing. */
444 auto import_from_treelite_model(treelite::Model const& tl_model,
445  tree_layout layout = preferred_tree_layout,
446  index_type align_bytes = index_type{},
447  std::optional<bool> use_double_precision = std::nullopt,
449  int device = 0,
451 {
452  auto result = forest_model{};
453  switch (layout) {
454  case tree_layout::depth_first:
455  result = treelite_importer<tree_layout::depth_first>{}.import(
456  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
457  break;
458  case tree_layout::breadth_first:
459  result = treelite_importer<tree_layout::breadth_first>{}.import(
460  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
461  break;
462  case tree_layout::layered_children_together:
463  result = treelite_importer<tree_layout::layered_children_together>{}.import(
464  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
465  break;
466  }
467  return result;
468 }
469 
495  tree_layout layout = preferred_tree_layout,
496  index_type align_bytes = index_type{},
497  std::optional<bool> use_double_precision = std::nullopt,
499  int device = 0,
501 {
// Reinterpret the opaque Treelite handle as a treelite::Model and delegate to
// import_from_treelite_model with the same arguments. NOTE(review): assumes
// tl_handle is a non-null handle to a live treelite::Model; confirm at call sites.
502  return import_from_treelite_model(*static_cast<treelite::Model*>(tl_handle),
503  layout,
504  align_bytes,
505  use_double_precision,
506  dev_type,
507  device,
508  stream);
509 }
510 
511 } // namespace fil
512 } // namespace experimental
513 } // namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
tree_layout
Definition: tree_layout.hpp:20
element_op
Definition: postproc_ops.hpp:29
uint32_t index_type
Definition: index_type.hpp:21
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 
7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:432
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:452
auto import_from_treelite_handle(TreeliteModelHandle tl_handle, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:494
row_op
Definition: postproc_ops.hpp:22
auto import_from_treelite_model(treelite::Model const &tl_model, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:444
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:177
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:188
Definition: dbscan.hpp:30
int cuda_stream
Definition: cuda_stream.hpp:25
device_type
Definition: device_type.hpp:18
Definition: treelite_importer.hpp:42
element_op element
Definition: treelite_importer.hpp:43
double constant
Definition: treelite_importer.hpp:45
row_op row
Definition: treelite_importer.hpp:44
Definition: exceptions.hpp:36
Definition: node.hpp:93
HOST DEVICE constexpr auto is_categorical() const
Definition: node.hpp:166
HOST DEVICE constexpr auto is_leaf() const
Definition: node.hpp:156
Definition: treelite_importer.hpp:55
auto import_to_specific_variant(index_type target_variant_index, treelite::Model const &tl_model, index_type num_class, index_type num_feature, index_type max_num_categories, std::vector< index_type > const &offsets, index_type align_bytes=index_type{}, raft_proto::device_type mem_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:246
auto uses_double_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:217
auto get_num_feature(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:113
auto get_tree_sizes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:100
auto get_num_leaf_vector_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:138
auto get_num_class(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:108
auto get_postproc_params(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:171
static constexpr auto const traversal_order
Definition: treelite_importer.hpp:56
auto get_node_count(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:69
auto get_num_categorical_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:128
auto get_bias(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:166
auto uses_double_thresholds(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:206
auto uses_integer_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:229
auto get_average_factor(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:150
auto get_offsets(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:76
auto get_max_num_categories(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:118
auto num_trees(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:92
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23