treelite_importer.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
27 
28 #include <treelite/c_api.h>
29 #include <treelite/enum/task_type.h>
30 #include <treelite/enum/tree_node_type.h>
31 #include <treelite/enum/typeinfo.h>
32 #include <treelite/tree.h>
33 
34 #include <cmath>
35 #include <variant>
36 
37 namespace ML {
38 namespace experimental {
39 namespace fil {
40 
41 namespace detail {
42 
46  double constant = 1.0;
47 };
48 } // namespace detail
49 
55 template <tree_layout layout>
57  auto static constexpr const traversal_order = []() constexpr {
58  if constexpr (layout == tree_layout::depth_first) {
60  } else if constexpr (layout == tree_layout::breadth_first) {
62  } else if constexpr (layout == tree_layout::layered_children_together) {
64  } else {
65  static_assert(layout == tree_layout::depth_first,
66  "Layout not yet implemented in treelite importer for FIL");
67  }
68  }();
69 
70  auto get_node_count(treelite::Model const& tl_model)
71  {
73  tl_model, index_type{}, [](auto&& count, auto&& tree) { return count + tree.num_nodes; });
74  }
75 
76  /* Return vector of offsets between each node and its most distant child */
77  auto get_offsets(treelite::Model const& tl_model)
78  {
79  auto node_count = get_node_count(tl_model);
80  auto result = std::vector<index_type>(node_count);
81  auto parent_indexes = std::vector<index_type>{};
82  parent_indexes.reserve(node_count);
83  ML::experimental::forest::node_transform<traversal_order>(
84  tl_model,
85  std::back_inserter(parent_indexes),
86  [](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) { return parent_index; });
87  for (auto i = std::size_t{}; i < node_count; ++i) {
88  result[parent_indexes[i]] = i - parent_indexes[i];
89  }
90  return result;
91  }
92 
93  auto num_trees(treelite::Model const& tl_model)
94  {
95  auto result = index_type{};
96  std::visit([&result](auto&& concrete_tl_model) { result = concrete_tl_model.trees.size(); },
97  tl_model.variant_);
98  return result;
99  }
100 
101  auto get_tree_sizes(treelite::Model const& tl_model)
102  {
103  auto result = std::vector<index_type>{};
105  tl_model, std::back_inserter(result), [](auto&& tree) { return tree.num_nodes; });
106  return result;
107  }
108 
109  auto get_num_class(treelite::Model const& tl_model)
110  {
111  return static_cast<index_type>(tl_model.num_class[0]);
112  }
113 
114  auto get_num_feature(treelite::Model const& tl_model)
115  {
116  return static_cast<index_type>(tl_model.num_feature);
117  }
118 
119  auto get_max_num_categories(treelite::Model const& tl_model)
120  {
121  return ML::experimental::forest::node_accumulate<traversal_order>(
122  tl_model,
123  index_type{},
124  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
125  return std::max(cur_accum, static_cast<index_type>(node.max_num_categories()));
126  });
127  }
128 
129  auto get_num_categorical_nodes(treelite::Model const& tl_model)
130  {
131  return ML::experimental::forest::node_accumulate<traversal_order>(
132  tl_model,
133  index_type{},
134  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
135  return cur_accum + static_cast<index_type>(node.is_categorical());
136  });
137  }
138 
139  auto get_num_leaf_vector_nodes(treelite::Model const& tl_model)
140  {
141  return ML::experimental::forest::node_accumulate<traversal_order>(
142  tl_model,
143  index_type{},
144  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
145  auto accum = cur_accum;
146  if (node.is_leaf() && node.get_output().size() > 1) { ++accum; }
147  return accum;
148  });
149  }
150 
151  auto get_average_factor(treelite::Model const& tl_model)
152  {
153  auto result = double{};
154  if (tl_model.average_tree_output) {
155  if (tl_model.task_type == treelite::TaskType::kMultiClf &&
156  tl_model.leaf_vector_shape[1] == 1) { // grove-per-class
157  result = num_trees(tl_model) / tl_model.num_class[0];
158  } else {
159  result = num_trees(tl_model);
160  }
161  } else {
162  result = 1.0;
163  }
164  return result;
165  }
166 
167  auto get_bias(treelite::Model const& tl_model)
168  {
169  return static_cast<double>(tl_model.base_scores[0]);
170  }
171 
172  auto get_postproc_params(treelite::Model const& tl_model)
173  {
174  auto result = detail::postproc_params_t{};
175  auto tl_pred_transform = tl_model.postprocessor;
176  if (tl_pred_transform == std::string{"identity"} ||
177  tl_pred_transform == std::string{"identity_multiclass"}) {
178  result.element = element_op::disable;
179  result.row = row_op::disable;
180  } else if (tl_pred_transform == std::string{"signed_square"}) {
181  result.element = element_op::signed_square;
182  } else if (tl_pred_transform == std::string{"hinge"}) {
183  result.element = element_op::hinge;
184  } else if (tl_pred_transform == std::string{"sigmoid"}) {
185  result.constant = tl_model.sigmoid_alpha;
186  result.element = element_op::sigmoid;
187  } else if (tl_pred_transform == std::string{"exponential"}) {
188  result.element = element_op::exponential;
189  } else if (tl_pred_transform == std::string{"exponential_standard_ratio"}) {
190  result.constant = -tl_model.ratio_c / std::log(2);
191  result.element = element_op::exponential;
192  } else if (tl_pred_transform == std::string{"logarithm_one_plus_exp"}) {
193  result.element = element_op::logarithm_one_plus_exp;
194  } else if (tl_pred_transform == std::string{"max_index"}) {
195  result.row = row_op::max_index;
196  } else if (tl_pred_transform == std::string{"softmax"}) {
197  result.row = row_op::softmax;
198  } else if (tl_pred_transform == std::string{"multiclass_ova"}) {
199  result.constant = tl_model.sigmoid_alpha;
200  result.element = element_op::sigmoid;
201  } else {
202  throw model_import_error{"Unrecognized Treelite pred_transform string"};
203  }
204  return result;
205  }
206 
207  auto uses_double_thresholds(treelite::Model const& tl_model)
208  {
209  auto result = false;
210  switch (tl_model.GetThresholdType()) {
211  case treelite::TypeInfo::kFloat64: result = true; break;
212  case treelite::TypeInfo::kFloat32: result = false; break;
213  default: throw model_import_error("Unrecognized Treelite threshold type");
214  }
215  return result;
216  }
217 
218  auto uses_double_outputs(treelite::Model const& tl_model)
219  {
220  auto result = false;
221  switch (tl_model.GetThresholdType()) {
222  case treelite::TypeInfo::kFloat64: result = true; break;
223  case treelite::TypeInfo::kFloat32: result = false; break;
224  case treelite::TypeInfo::kUInt32: result = false; break;
225  default: throw model_import_error("Unrecognized Treelite threshold type");
226  }
227  return result;
228  }
229 
230  auto uses_integer_outputs(treelite::Model const& tl_model)
231  {
232  auto result = false;
233  switch (tl_model.GetThresholdType()) {
234  case treelite::TypeInfo::kFloat64: result = false; break;
235  case treelite::TypeInfo::kFloat32: result = false; break;
236  case treelite::TypeInfo::kUInt32: result = true; break;
237  default: throw model_import_error("Unrecognized Treelite threshold type");
238  }
239  return result;
240  }
241 
246  template <index_type variant_index>
247  auto import_to_specific_variant(index_type target_variant_index,
248  treelite::Model const& tl_model,
249  index_type num_class,
250  index_type num_feature,
251  index_type max_num_categories,
252  std::vector<index_type> const& offsets,
253  index_type align_bytes = index_type{},
255  int device = 0,
257  {
258  auto result = decision_forest_variant{};
259  if constexpr (variant_index != std::variant_size_v<decision_forest_variant>) {
260  if (variant_index == target_variant_index) {
261  using forest_model_t = std::variant_alternative_t<variant_index, decision_forest_variant>;
262  if constexpr (traversal_order ==
264  // Cannot align whole trees with layered traversal order, since trees
265  // are mingled together
266  align_bytes = index_type{};
267  }
268  auto builder =
269  detail::decision_forest_builder<forest_model_t>(max_num_categories, align_bytes);
270  auto node_index = index_type{};
271  ML::experimental::forest::node_for_each<traversal_order>(
272  tl_model,
273  [&builder, &offsets, &node_index](
274  auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
275  if (node.is_leaf()) {
276  auto output = node.get_output();
277  builder.set_output_size(output.size());
278  if (output.size() > index_type{1}) {
279  builder.add_leaf_vector_node(
280  std::begin(output), std::end(output), node.get_treelite_id(), depth);
281  } else {
282  builder.add_node(
283  typename forest_model_t::io_type(output[0]), node.get_treelite_id(), depth, true);
284  }
285  } else {
286  if (node.is_categorical()) {
287  auto categories = node.get_categories();
288  builder.add_categorical_node(std::begin(categories),
289  std::end(categories),
290  node.get_treelite_id(),
291  depth,
292  node.default_distant(),
293  node.get_feature(),
294  offsets[node_index]);
295  } else {
296  builder.add_node(typename forest_model_t::threshold_type(node.threshold()),
297  node.get_treelite_id(),
298  depth,
299  false,
300  node.default_distant(),
301  false,
302  node.get_feature(),
303  offsets[node_index],
304  node.is_inclusive());
305  }
306  }
307  ++node_index;
308  });
309 
310  builder.set_average_factor(get_average_factor(tl_model));
311  builder.set_bias(get_bias(tl_model));
312  auto postproc_params = get_postproc_params(tl_model);
313  builder.set_element_postproc(postproc_params.element);
314  builder.set_row_postproc(postproc_params.row);
315  builder.set_postproc_constant(postproc_params.constant);
316 
317  result.template emplace<variant_index>(
318  builder.get_decision_forest(num_feature, num_class, mem_type, device, stream));
319  } else {
320  result = import_to_specific_variant<variant_index + 1>(target_variant_index,
321  tl_model,
322  num_class,
323  num_feature,
324  max_num_categories,
325  offsets,
326  align_bytes,
327  mem_type,
328  device,
329  stream);
330  }
331  }
332  return result;
333  }
334 
357  forest_model import(treelite::Model const& tl_model,
358  index_type align_bytes = index_type{},
359  std::optional<bool> use_double_precision = std::nullopt,
361  int device = 0,
363  {
364  // Handle degenerate trees (a single root node with no child)
365  if (auto processed_tl_model = detail::convert_degenerate_trees(tl_model); processed_tl_model) {
366  return import(
367  *processed_tl_model.get(), align_bytes, use_double_precision, dev_type, device, stream);
368  }
369 
370  ASSERT(tl_model.num_target == 1, "FIL does not support multi-target model");
371  // Check tree annotation (assignment)
372  if (tl_model.task_type == treelite::TaskType::kMultiClf) {
373  // Must be either vector leaf or grove-per-class
374  if (tl_model.leaf_vector_shape[1] > 1) { // vector-leaf
375  ASSERT(tl_model.leaf_vector_shape[1] == int(tl_model.num_class[0]),
376  "Vector leaf must be equal to num_class = %d",
377  tl_model.num_class[0]);
378  auto tree_count = num_trees(tl_model);
379  for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
380  ASSERT(tl_model.class_id[tree_id] == -1, "Tree %d has invalid class assignment", tree_id);
381  }
382  } else { // grove-per-class
383  auto tree_count = num_trees(tl_model);
384  for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
385  ASSERT(tl_model.class_id[tree_id] == int(tree_id % tl_model.num_class[0]),
386  "Tree %d has invalid class assignment",
387  tree_id);
388  }
389  }
390  }
391  // Check base_scores
392  for (std::int32_t class_id = 1; class_id < tl_model.num_class[0]; ++class_id) {
393  ASSERT(tl_model.base_scores[0] == tl_model.base_scores[class_id],
394  "base_scores must be identical for all classes");
395  }
396 
397  auto result = decision_forest_variant{};
398  auto num_feature = get_num_feature(tl_model);
399  auto max_num_categories = get_max_num_categories(tl_model);
400  auto num_categorical_nodes = get_num_categorical_nodes(tl_model);
401  auto num_leaf_vector_nodes = get_num_leaf_vector_nodes(tl_model);
402  auto use_double_thresholds = use_double_precision.value_or(uses_double_thresholds(tl_model));
403 
404  auto offsets = get_offsets(tl_model);
405  auto max_offset = *std::max_element(std::begin(offsets), std::end(offsets));
406 
407  auto variant_index = get_forest_variant_index(use_double_thresholds,
408  max_offset,
409  num_feature,
410  num_categorical_nodes,
411  max_num_categories,
412  num_leaf_vector_nodes,
413  layout);
414  auto num_class = get_num_class(tl_model);
415  return forest_model{import_to_specific_variant<index_type{}>(variant_index,
416  tl_model,
417  num_class,
418  num_feature,
419  max_num_categories,
420  offsets,
421  align_bytes,
422  dev_type,
423  device,
424  stream)};
425  }
426 };
427 
451 auto import_from_treelite_model(treelite::Model const& tl_model,
452  tree_layout layout = preferred_tree_layout,
453  index_type align_bytes = index_type{},
454  std::optional<bool> use_double_precision = std::nullopt,
456  int device = 0,
458 {
459  auto result = forest_model{};
460  switch (layout) {
461  case tree_layout::depth_first:
462  result = treelite_importer<tree_layout::depth_first>{}.import(
463  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
464  break;
465  case tree_layout::breadth_first:
466  result = treelite_importer<tree_layout::breadth_first>{}.import(
467  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
468  break;
469  case tree_layout::layered_children_together:
470  result = treelite_importer<tree_layout::layered_children_together>{}.import(
471  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
472  break;
473  }
474  return result;
475 }
476 
502  tree_layout layout = preferred_tree_layout,
503  index_type align_bytes = index_type{},
504  std::optional<bool> use_double_precision = std::nullopt,
506  int device = 0,
508 {
509  return import_from_treelite_model(*static_cast<treelite::Model*>(tl_handle),
510  layout,
511  align_bytes,
512  use_double_precision,
513  dev_type,
514  device,
515  stream);
516 }
517 
518 } // namespace fil
519 } // namespace experimental
520 } // namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
std::unique_ptr< treelite::Model > convert_degenerate_trees(treelite::Model const &tl_model)
Definition: degenerate_trees.hpp:31
tree_layout
Definition: tree_layout.hpp:20
element_op
Definition: postproc_ops.hpp:29
uint32_t index_type
Definition: index_type.hpp:21
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:432
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:452
auto import_from_treelite_handle(TreeliteModelHandle tl_handle, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:501
row_op
Definition: postproc_ops.hpp:22
auto import_from_treelite_model(treelite::Model const &tl_model, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:451
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:177
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:188
Definition: dbscan.hpp:30
int cuda_stream
Definition: cuda_stream.hpp:25
device_type
Definition: device_type.hpp:18
Definition: treelite_importer.hpp:43
element_op element
Definition: treelite_importer.hpp:44
double constant
Definition: treelite_importer.hpp:46
row_op row
Definition: treelite_importer.hpp:45
Definition: forest_model.hpp:38
Definition: exceptions.hpp:36
Definition: node.hpp:93
HOST DEVICE constexpr auto is_categorical() const
Definition: node.hpp:166
HOST DEVICE constexpr auto is_leaf() const
Definition: node.hpp:156
Definition: treelite_importer.hpp:56
auto import_to_specific_variant(index_type target_variant_index, treelite::Model const &tl_model, index_type num_class, index_type num_feature, index_type max_num_categories, std::vector< index_type > const &offsets, index_type align_bytes=index_type{}, raft_proto::device_type mem_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:247
auto uses_double_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:218
auto get_num_feature(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:114
auto get_tree_sizes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:101
auto get_num_leaf_vector_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:139
auto get_num_class(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:109
auto get_postproc_params(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:172
static constexpr auto const traversal_order
Definition: treelite_importer.hpp:57
auto get_node_count(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:70
auto get_num_categorical_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:129
auto get_bias(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:167
auto uses_double_thresholds(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:207
auto uses_integer_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:230
auto get_average_factor(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:151
auto get_offsets(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:77
auto get_max_num_categories(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:119
auto num_trees(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:93
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23