treelite_importer.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 #include <cuml/fil/constants.hpp>
11 #include <cuml/fil/exceptions.hpp>
14 #include <cuml/fil/tree_layout.hpp>
16 
17 #include <raft/core/error.hpp>
18 
19 #include <treelite/c_api.h>
20 #include <treelite/enum/task_type.h>
21 #include <treelite/enum/tree_node_type.h>
22 #include <treelite/enum/typeinfo.h>
23 #include <treelite/tree.h>
24 
25 #include <cmath>
26 #include <variant>
27 
28 namespace ML {
29 namespace fil {
30 
31 namespace detail {
32 
36  double constant = 1.0;
37 };
38 } // namespace detail
39 
45 template <tree_layout layout>
// Compile-time constant derived from the `layout` template parameter. An
// immediately-invoked constexpr lambda has one `if constexpr` branch per
// supported tree_layout value.
// NOTE(review): this listing skips original lines 49, 51 and 53, so the value
// produced by each branch is not visible here - confirm against the real header.
47  auto static constexpr const traversal_order = []() constexpr {
48  if constexpr (layout == tree_layout::depth_first) {
50  } else if constexpr (layout == tree_layout::breadth_first) {
52  } else if constexpr (layout == tree_layout::layered_children_together) {
54  } else {
// Any other layout value is rejected at compile time with the message below.
55  static_assert(layout == tree_layout::depth_first,
56  "Layout not yet implemented in treelite importer for FIL");
57  }
58  }();
59 
// Node count over all trees of the model: the visible lambda adds each tree's
// num_nodes to an accumulator that starts at index_type{}.
// NOTE(review): original line 62 (the start of the statement that receives
// these arguments) is missing from this listing - presumably a
// `return ML::forest::tree_accumulate(` call; confirm against the real header.
60  auto get_node_count(treelite::Model const& tl_model)
61  {
63  tl_model, index_type{}, [](auto&& count, auto&& tree) { return count + tree.num_nodes; });
64  }
65 
66  /* Return vector of offsets between each node and its most distant child */
67  auto get_offsets(treelite::Model const& tl_model)
68  {
69  auto node_count = get_node_count(tl_model);
70  auto result = std::vector<index_type>(node_count);
71  auto parent_indexes = std::vector<index_type>{};
72  parent_indexes.reserve(node_count);
73  ML::forest::node_transform<traversal_order>(
74  tl_model,
75  std::back_inserter(parent_indexes),
76  [](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) { return parent_index; });
77  for (auto i = std::size_t{}; i < node_count; ++i) {
78  result[parent_indexes[i]] = i - parent_indexes[i];
79  }
80  return result;
81  }
82 
83  auto num_trees(treelite::Model const& tl_model)
84  {
85  auto result = index_type{};
86  std::visit([&result](auto&& concrete_tl_model) { result = concrete_tl_model.trees.size(); },
87  tl_model.variant_);
88  return result;
89  }
90 
// Builds a vector of per-tree sizes: the visible lambda maps each tree to its
// num_nodes and the values are appended to `result` through a back_inserter.
// NOTE(review): original line 94 (the start of the call that receives these
// arguments) is missing from this listing - presumably
// `ML::forest::tree_transform(`; confirm against the real header.
91  auto get_tree_sizes(treelite::Model const& tl_model)
92  {
93  auto result = std::vector<index_type>{};
95  tl_model, std::back_inserter(result), [](auto&& tree) { return tree.num_nodes; });
96  return result;
97  }
98 
99  auto get_num_class(treelite::Model const& tl_model)
100  {
101  return static_cast<index_type>(tl_model.num_class[0]);
102  }
103 
104  auto get_num_feature(treelite::Model const& tl_model)
105  {
106  return static_cast<index_type>(tl_model.num_feature);
107  }
108 
109  auto get_max_num_categories(treelite::Model const& tl_model)
110  {
111  return ML::forest::node_accumulate<traversal_order>(
112  tl_model,
113  index_type{},
114  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
115  return std::max(cur_accum, static_cast<index_type>(node.max_num_categories()));
116  });
117  }
118 
119  auto get_num_categorical_nodes(treelite::Model const& tl_model)
120  {
121  return ML::forest::node_accumulate<traversal_order>(
122  tl_model,
123  index_type{},
124  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
125  return cur_accum + static_cast<index_type>(node.is_categorical());
126  });
127  }
128 
129  auto get_num_leaf_vector_nodes(treelite::Model const& tl_model)
130  {
131  return ML::forest::node_accumulate<traversal_order>(
132  tl_model,
133  index_type{},
134  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
135  auto accum = cur_accum;
136  if (node.is_leaf() && node.get_output().size() > 1) { ++accum; }
137  return accum;
138  });
139  }
140 
141  auto get_average_factor(treelite::Model const& tl_model)
142  {
143  auto result = double{};
144  if (tl_model.average_tree_output) {
145  if (tl_model.task_type == treelite::TaskType::kMultiClf &&
146  tl_model.leaf_vector_shape[1] == 1) { // grove-per-class
147  result = num_trees(tl_model) / tl_model.num_class[0];
148  } else {
149  result = num_trees(tl_model);
150  }
151  } else {
152  result = 1.0;
153  }
154  return result;
155  }
156 
157  auto get_bias(treelite::Model const& tl_model) { return tl_model.base_scores.AsVector(); }
158 
159  auto get_postproc_params(treelite::Model const& tl_model)
160  {
161  auto result = detail::postproc_params_t{};
162  auto tl_pred_transform = tl_model.postprocessor;
163  if (tl_pred_transform == std::string{"identity"} ||
164  tl_pred_transform == std::string{"identity_multiclass"}) {
165  result.element = element_op::disable;
166  result.row = row_op::disable;
167  } else if (tl_pred_transform == std::string{"signed_square"}) {
168  result.element = element_op::signed_square;
169  } else if (tl_pred_transform == std::string{"hinge"}) {
170  result.element = element_op::hinge;
171  } else if (tl_pred_transform == std::string{"sigmoid"}) {
172  result.constant = tl_model.sigmoid_alpha;
173  result.element = element_op::sigmoid;
174  } else if (tl_pred_transform == std::string{"exponential"}) {
175  result.element = element_op::exponential;
176  } else if (tl_pred_transform == std::string{"exponential_standard_ratio"}) {
177  result.constant = -tl_model.ratio_c / std::log(2);
178  result.element = element_op::exponential;
179  } else if (tl_pred_transform == std::string{"logarithm_one_plus_exp"}) {
180  result.element = element_op::logarithm_one_plus_exp;
181  } else if (tl_pred_transform == std::string{"max_index"}) {
182  result.row = row_op::max_index;
183  } else if (tl_pred_transform == std::string{"softmax"}) {
184  result.row = row_op::softmax;
185  } else if (tl_pred_transform == std::string{"multiclass_ova"}) {
186  result.constant = tl_model.sigmoid_alpha;
187  result.element = element_op::sigmoid;
188  } else {
189  throw model_import_error{"Unrecognized Treelite pred_transform string"};
190  }
191  return result;
192  }
193 
194  auto uses_double_thresholds(treelite::Model const& tl_model)
195  {
196  auto result = false;
197  switch (tl_model.GetThresholdType()) {
198  case treelite::TypeInfo::kFloat64: result = true; break;
199  case treelite::TypeInfo::kFloat32: result = false; break;
200  default: throw model_import_error("Unrecognized Treelite threshold type");
201  }
202  return result;
203  }
204 
205  auto uses_double_outputs(treelite::Model const& tl_model)
206  {
207  auto result = false;
208  switch (tl_model.GetThresholdType()) {
209  case treelite::TypeInfo::kFloat64: result = true; break;
210  case treelite::TypeInfo::kFloat32: result = false; break;
211  case treelite::TypeInfo::kUInt32: result = false; break;
212  default: throw model_import_error("Unrecognized Treelite threshold type");
213  }
214  return result;
215  }
216 
217  auto uses_integer_outputs(treelite::Model const& tl_model)
218  {
219  auto result = false;
220  switch (tl_model.GetThresholdType()) {
221  case treelite::TypeInfo::kFloat64: result = false; break;
222  case treelite::TypeInfo::kFloat32: result = false; break;
223  case treelite::TypeInfo::kUInt32: result = true; break;
224  default: throw model_import_error("Unrecognized Treelite threshold type");
225  }
226  return result;
227  }
228 
// Builds the decision_forest_variant alternative whose index equals
// target_variant_index. The template parameter variant_index walks the
// alternatives one at a time through recursive instantiation; the
// `if constexpr` below stops the recursion once variant_index reaches
// std::variant_size_v<decision_forest_variant>. If no alternative matches, the
// default-constructed `result` is returned.
// NOTE(review): original lines 241 and 243 of the parameter list are missing
// from this listing. The body uses `mem_type` and `stream`, and the listing's
// cross-reference index shows them as `raft_proto::device_type mem_type` and
// `raft_proto::cuda_stream stream` - confirm against the real header.
233  template <index_type variant_index>
234  auto import_to_specific_variant(index_type target_variant_index,
235  treelite::Model const& tl_model,
236  index_type num_class,
237  index_type num_feature,
238  index_type max_num_categories,
239  std::vector<index_type> const& offsets,
240  index_type align_bytes = index_type{},
242  int device = 0,
244  {
245  auto result = decision_forest_variant{};
246  if constexpr (variant_index != std::variant_size_v<decision_forest_variant>) {
247  if (variant_index == target_variant_index) {
248  using forest_model_t = std::variant_alternative_t<variant_index, decision_forest_variant>;
// NOTE(review): original line 249 (the condition opening the block closed on
// line 253) is missing from this listing; per the comment below it presumably
// tests for the layered traversal order - confirm.
250  // Cannot align whole trees with layered traversal order, since trees
251  // are mingled together
252  align_bytes = index_type{};
253  }
254  auto builder =
255  detail::decision_forest_builder<forest_model_t>(max_num_categories, align_bytes);
// node_index counts visited nodes so each one can look up its entry in `offsets`.
256  auto node_index = index_type{};
257  ML::forest::node_for_each<traversal_order>(
258  tl_model,
259  [&builder, &offsets, &node_index](
260  auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
261  if (node.is_leaf()) {
262  auto output = node.get_output();
263  builder.set_output_size(output.size());
// Leaves with more than one output value are added as vector leaves; otherwise
// the single value is converted to the forest's io_type.
264  if (output.size() > index_type{1}) {
265  builder.add_leaf_vector_node(
266  std::begin(output), std::end(output), node.get_treelite_id(), depth);
267  } else {
268  builder.add_node(
269  typename forest_model_t::io_type(output[0]), node.get_treelite_id(), depth, true);
270  }
271  } else {
// Non-leaf nodes: categorical splits pass their category list, all other
// splits pass a threshold converted to the forest's threshold_type.
272  if (node.is_categorical()) {
273  auto categories = node.get_categories();
274  builder.add_categorical_node(std::begin(categories),
275  std::end(categories),
276  node.get_treelite_id(),
277  depth,
278  node.default_distant(),
279  node.get_feature(),
280  offsets[node_index]);
281  } else {
282  builder.add_node(typename forest_model_t::threshold_type(node.threshold()),
283  node.get_treelite_id(),
284  depth,
285  false,
286  node.default_distant(),
287  false,
288  node.get_feature(),
289  offsets[node_index],
290  node.is_inclusive());
291  }
292  }
293  ++node_index;
294  });
295 
// Model-wide settings: averaging factor, bias and postprocessing.
296  builder.set_average_factor(get_average_factor(tl_model));
297  builder.set_bias(get_bias(tl_model));
298  auto postproc_params = get_postproc_params(tl_model);
299  builder.set_element_postproc(postproc_params.element);
300  builder.set_row_postproc(postproc_params.row);
301  builder.set_postproc_constant(postproc_params.constant);
302 
303  result.template emplace<variant_index>(
304  builder.get_decision_forest(num_feature, num_class, mem_type, device, stream));
305  } else {
// Not the requested alternative: try the next variant index.
306  result = import_to_specific_variant<variant_index + 1>(target_variant_index,
307  tl_model,
308  num_class,
309  num_feature,
310  max_num_categories,
311  offsets,
312  align_bytes,
313  mem_type,
314  device,
315  stream);
316  }
317  }
318  return result;
319  }
320 
// Imports a Treelite model into a forest_model using this importer's layout.
// Steps visible below: convert degenerate trees and re-import if needed,
// validate the model, gather model statistics, pick a forest variant index,
// then build that variant with import_to_specific_variant.
// NOTE(review): original lines 346 and 348 of the parameter list are missing
// from this listing; the body uses `dev_type` and `stream`, presumably declared
// on those lines - confirm against the real header.
343  forest_model import(treelite::Model const& tl_model,
344  index_type align_bytes = index_type{},
345  std::optional<bool> use_double_precision = std::nullopt,
347  int device = 0,
349  {
350  // Handle degenerate trees (a single root node with no child)
// When convert_degenerate_trees yields a non-null model, that model is
// imported instead of the original.
351  if (auto processed_tl_model = detail::convert_degenerate_trees(tl_model); processed_tl_model) {
352  return import(
353  *processed_tl_model.get(), align_bytes, use_double_precision, dev_type, device, stream);
354  }
355 
356  ASSERT(tl_model.num_target == 1, "FIL does not support multi-target model");
357  // Check tree annotation (assignment)
358  if (tl_model.task_type == treelite::TaskType::kMultiClf) {
359  // Must be either vector leaf or grove-per-class
360  if (tl_model.leaf_vector_shape[1] > 1) { // vector-leaf
361  ASSERT(tl_model.leaf_vector_shape[1] == int(tl_model.num_class[0]),
362  "Vector leaf must be equal to num_class = %d",
363  tl_model.num_class[0]);
364  auto tree_count = num_trees(tl_model);
// Vector-leaf models: every tree must carry class_id == -1.
365  for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
366  ASSERT(tl_model.class_id[tree_id] == -1, "Tree %d has invalid class assignment", tree_id);
367  }
368  } else { // grove-per-class
369  auto tree_count = num_trees(tl_model);
// Grove-per-class models: tree i must be assigned to class i % num_class[0].
370  for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
371  ASSERT(tl_model.class_id[tree_id] == int(tree_id % tl_model.num_class[0]),
372  "Tree %d has invalid class assignment",
373  tree_id);
374  }
375  }
376  }
377 
// NOTE(review): `result` is not referenced again in the visible part of this
// function - possibly a leftover.
378  auto result = decision_forest_variant{};
379  auto num_feature = get_num_feature(tl_model);
380  auto max_num_categories = get_max_num_categories(tl_model);
381  auto num_categorical_nodes = get_num_categorical_nodes(tl_model);
382  auto num_leaf_vector_nodes = get_num_leaf_vector_nodes(tl_model);
// An explicit use_double_precision overrides the model's own threshold type.
// uses_double_thresholds is evaluated (and may throw) even when it is set.
383  auto use_double_thresholds = use_double_precision.value_or(uses_double_thresholds(tl_model));
384 
385  auto offsets = get_offsets(tl_model);
// NOTE(review): if `offsets` is empty (a model with no nodes), dereferencing
// the iterator returned by std::max_element is undefined behavior - confirm
// that callers guarantee at least one node.
386  auto max_offset = *std::max_element(std::begin(offsets), std::end(offsets));
387 
388  auto variant_index = get_forest_variant_index(use_double_thresholds,
389  max_offset,
390  num_feature,
391  num_categorical_nodes,
392  max_num_categories,
393  num_leaf_vector_nodes,
394  layout);
395  auto num_class = get_num_class(tl_model);
396  return forest_model{import_to_specific_variant<index_type{}>(variant_index,
397  tl_model,
398  num_class,
399  num_feature,
400  max_num_categories,
401  offsets,
402  align_bytes,
403  dev_type,
404  device,
405  stream)};
406  }
407 };
408 
// Imports a Treelite model with a layout chosen at run time: the switch maps
// each tree_layout value to the matching treelite_importer<layout>
// instantiation and forwards the remaining arguments to its import().
// NOTE(review): original lines 436 and 438 of the parameter list are missing
// from this listing; the listing's cross-reference index shows them as
// `raft_proto::device_type dev_type` and `raft_proto::cuda_stream stream` -
// confirm against the real header.
432 auto import_from_treelite_model(treelite::Model const& tl_model,
433  tree_layout layout = preferred_tree_layout,
434  index_type align_bytes = index_type{},
435  std::optional<bool> use_double_precision = std::nullopt,
437  int device = 0,
439 {
440  auto result = forest_model{};
// There is no default case: a layout value outside the three cases below
// returns the default-constructed forest_model.
441  switch (layout) {
442  case tree_layout::depth_first:
443  result = treelite_importer<tree_layout::depth_first>{}.import(
444  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
445  break;
446  case tree_layout::breadth_first:
447  result = treelite_importer<tree_layout::breadth_first>{}.import(
448  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
449  break;
450  case tree_layout::layered_children_together:
451  result = treelite_importer<tree_layout::layered_children_together>{}.import(
452  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
453  break;
454  }
455  return result;
456 }
457 
483  tree_layout layout = preferred_tree_layout,
484  index_type align_bytes = index_type{},
485  std::optional<bool> use_double_precision = std::nullopt,
487  int device = 0,
489 {
// NOTE(review): the first signature line (original line 482) and original
// lines 486 and 488 are missing from this listing. The handle is cast to
// treelite::Model* and dereferenced with no null or type check, so the caller
// must pass a valid handle to a treelite::Model. All other arguments are
// forwarded unchanged to import_from_treelite_model.
490  return import_from_treelite_model(*static_cast<treelite::Model*>(tl_handle),
491  layout,
492  align_bytes,
493  use_double_precision,
494  dev_type,
495  device,
496  stream);
497 }
498 
499 } // namespace fil
500 } // namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:16
std::unique_ptr< treelite::Model > convert_degenerate_trees(treelite::Model const &tl_model)
Definition: degenerate_trees.hpp:20
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:445
tree_layout
Definition: tree_layout.hpp:8
row_op
Definition: postproc_ops.hpp:10
element_op
Definition: postproc_ops.hpp:17
auto import_from_treelite_handle(TreeliteModelHandle tl_handle, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:482
auto import_from_treelite_model(treelite::Model const &tl_model, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:432
uint32_t index_type
Definition: index_type.hpp:9
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 
7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:425
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:183
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:172
Definition: dbscan.hpp:18
int cuda_stream
Definition: cuda_stream.hpp:14
device_type
Definition: device_type.hpp:7
Definition: treelite_importer.hpp:33
element_op element
Definition: treelite_importer.hpp:34
row_op row
Definition: treelite_importer.hpp:35
double constant
Definition: treelite_importer.hpp:36
Definition: forest_model.hpp:29
Definition: exceptions.hpp:24
Definition: node.hpp:81
HOST DEVICE constexpr auto is_categorical() const
Definition: node.hpp:154
HOST DEVICE constexpr auto is_leaf() const
Definition: node.hpp:144
Definition: treelite_importer.hpp:46
auto uses_integer_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:217
auto get_postproc_params(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:159
auto get_num_feature(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:104
auto get_max_num_categories(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:109
auto uses_double_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:205
auto get_bias(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:157
static constexpr auto const traversal_order
Definition: treelite_importer.hpp:47
auto get_num_leaf_vector_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:129
auto import_to_specific_variant(index_type target_variant_index, treelite::Model const &tl_model, index_type num_class, index_type num_feature, index_type max_num_categories, std::vector< index_type > const &offsets, index_type align_bytes=index_type{}, raft_proto::device_type mem_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:234
auto uses_double_thresholds(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:194
auto num_trees(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:83
auto get_num_categorical_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:119
auto get_tree_sizes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:91
auto get_offsets(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:67
auto get_node_count(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:60
auto get_num_class(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:99
auto get_average_factor(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:141
void * TreeliteModelHandle
Definition: treelite_defs.hpp:12