treelite_importer.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 #include <cuml/fil/constants.hpp>
22 #include <cuml/fil/exceptions.hpp>
25 #include <cuml/fil/tree_layout.hpp>
27 
28 #include <treelite/c_api.h>
29 #include <treelite/enum/task_type.h>
30 #include <treelite/enum/tree_node_type.h>
31 #include <treelite/enum/typeinfo.h>
32 #include <treelite/tree.h>
33 
34 #include <cmath>
35 #include <variant>
36 
37 namespace ML {
38 namespace fil {
39 
40 namespace detail {
41 
45  double constant = 1.0;
46 };
47 } // namespace detail
48 
54 template <tree_layout layout>
56  auto static constexpr const traversal_order = []() constexpr {
57  if constexpr (layout == tree_layout::depth_first) {
59  } else if constexpr (layout == tree_layout::breadth_first) {
61  } else if constexpr (layout == tree_layout::layered_children_together) {
63  } else {
64  static_assert(layout == tree_layout::depth_first,
65  "Layout not yet implemented in treelite importer for FIL");
66  }
67  }();
68 
69  auto get_node_count(treelite::Model const& tl_model)
70  {
72  tl_model, index_type{}, [](auto&& count, auto&& tree) { return count + tree.num_nodes; });
73  }
74 
75  /* Return vector of offsets between each node and its most distant child */
76  auto get_offsets(treelite::Model const& tl_model)
77  {
78  auto node_count = get_node_count(tl_model);
79  auto result = std::vector<index_type>(node_count);
80  auto parent_indexes = std::vector<index_type>{};
81  parent_indexes.reserve(node_count);
82  ML::forest::node_transform<traversal_order>(
83  tl_model,
84  std::back_inserter(parent_indexes),
85  [](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) { return parent_index; });
86  for (auto i = std::size_t{}; i < node_count; ++i) {
87  result[parent_indexes[i]] = i - parent_indexes[i];
88  }
89  return result;
90  }
91 
92  auto num_trees(treelite::Model const& tl_model)
93  {
94  auto result = index_type{};
95  std::visit([&result](auto&& concrete_tl_model) { result = concrete_tl_model.trees.size(); },
96  tl_model.variant_);
97  return result;
98  }
99 
100  auto get_tree_sizes(treelite::Model const& tl_model)
101  {
102  auto result = std::vector<index_type>{};
104  tl_model, std::back_inserter(result), [](auto&& tree) { return tree.num_nodes; });
105  return result;
106  }
107 
108  auto get_num_class(treelite::Model const& tl_model)
109  {
110  return static_cast<index_type>(tl_model.num_class[0]);
111  }
112 
113  auto get_num_feature(treelite::Model const& tl_model)
114  {
115  return static_cast<index_type>(tl_model.num_feature);
116  }
117 
118  auto get_max_num_categories(treelite::Model const& tl_model)
119  {
120  return ML::forest::node_accumulate<traversal_order>(
121  tl_model,
122  index_type{},
123  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
124  return std::max(cur_accum, static_cast<index_type>(node.max_num_categories()));
125  });
126  }
127 
128  auto get_num_categorical_nodes(treelite::Model const& tl_model)
129  {
130  return ML::forest::node_accumulate<traversal_order>(
131  tl_model,
132  index_type{},
133  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
134  return cur_accum + static_cast<index_type>(node.is_categorical());
135  });
136  }
137 
138  auto get_num_leaf_vector_nodes(treelite::Model const& tl_model)
139  {
140  return ML::forest::node_accumulate<traversal_order>(
141  tl_model,
142  index_type{},
143  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
144  auto accum = cur_accum;
145  if (node.is_leaf() && node.get_output().size() > 1) { ++accum; }
146  return accum;
147  });
148  }
149 
150  auto get_average_factor(treelite::Model const& tl_model)
151  {
152  auto result = double{};
153  if (tl_model.average_tree_output) {
154  if (tl_model.task_type == treelite::TaskType::kMultiClf &&
155  tl_model.leaf_vector_shape[1] == 1) { // grove-per-class
156  result = num_trees(tl_model) / tl_model.num_class[0];
157  } else {
158  result = num_trees(tl_model);
159  }
160  } else {
161  result = 1.0;
162  }
163  return result;
164  }
165 
166  auto get_bias(treelite::Model const& tl_model)
167  {
168  return static_cast<double>(tl_model.base_scores[0]);
169  }
170 
171  auto get_postproc_params(treelite::Model const& tl_model)
172  {
173  auto result = detail::postproc_params_t{};
174  auto tl_pred_transform = tl_model.postprocessor;
175  if (tl_pred_transform == std::string{"identity"} ||
176  tl_pred_transform == std::string{"identity_multiclass"}) {
177  result.element = element_op::disable;
178  result.row = row_op::disable;
179  } else if (tl_pred_transform == std::string{"signed_square"}) {
180  result.element = element_op::signed_square;
181  } else if (tl_pred_transform == std::string{"hinge"}) {
182  result.element = element_op::hinge;
183  } else if (tl_pred_transform == std::string{"sigmoid"}) {
184  result.constant = tl_model.sigmoid_alpha;
185  result.element = element_op::sigmoid;
186  } else if (tl_pred_transform == std::string{"exponential"}) {
187  result.element = element_op::exponential;
188  } else if (tl_pred_transform == std::string{"exponential_standard_ratio"}) {
189  result.constant = -tl_model.ratio_c / std::log(2);
190  result.element = element_op::exponential;
191  } else if (tl_pred_transform == std::string{"logarithm_one_plus_exp"}) {
192  result.element = element_op::logarithm_one_plus_exp;
193  } else if (tl_pred_transform == std::string{"max_index"}) {
194  result.row = row_op::max_index;
195  } else if (tl_pred_transform == std::string{"softmax"}) {
196  result.row = row_op::softmax;
197  } else if (tl_pred_transform == std::string{"multiclass_ova"}) {
198  result.constant = tl_model.sigmoid_alpha;
199  result.element = element_op::sigmoid;
200  } else {
201  throw model_import_error{"Unrecognized Treelite pred_transform string"};
202  }
203  return result;
204  }
205 
206  auto uses_double_thresholds(treelite::Model const& tl_model)
207  {
208  auto result = false;
209  switch (tl_model.GetThresholdType()) {
210  case treelite::TypeInfo::kFloat64: result = true; break;
211  case treelite::TypeInfo::kFloat32: result = false; break;
212  default: throw model_import_error("Unrecognized Treelite threshold type");
213  }
214  return result;
215  }
216 
217  auto uses_double_outputs(treelite::Model const& tl_model)
218  {
219  auto result = false;
220  switch (tl_model.GetThresholdType()) {
221  case treelite::TypeInfo::kFloat64: result = true; break;
222  case treelite::TypeInfo::kFloat32: result = false; break;
223  case treelite::TypeInfo::kUInt32: result = false; break;
224  default: throw model_import_error("Unrecognized Treelite threshold type");
225  }
226  return result;
227  }
228 
229  auto uses_integer_outputs(treelite::Model const& tl_model)
230  {
231  auto result = false;
232  switch (tl_model.GetThresholdType()) {
233  case treelite::TypeInfo::kFloat64: result = false; break;
234  case treelite::TypeInfo::kFloat32: result = false; break;
235  case treelite::TypeInfo::kUInt32: result = true; break;
236  default: throw model_import_error("Unrecognized Treelite threshold type");
237  }
238  return result;
239  }
240 
245  template <index_type variant_index>
246  auto import_to_specific_variant(index_type target_variant_index,
247  treelite::Model const& tl_model,
248  index_type num_class,
249  index_type num_feature,
250  index_type max_num_categories,
251  std::vector<index_type> const& offsets,
252  index_type align_bytes = index_type{},
254  int device = 0,
256  {
257  auto result = decision_forest_variant{};
258  if constexpr (variant_index != std::variant_size_v<decision_forest_variant>) {
259  if (variant_index == target_variant_index) {
260  using forest_model_t = std::variant_alternative_t<variant_index, decision_forest_variant>;
262  // Cannot align whole trees with layered traversal order, since trees
263  // are mingled together
264  align_bytes = index_type{};
265  }
266  auto builder =
267  detail::decision_forest_builder<forest_model_t>(max_num_categories, align_bytes);
268  auto node_index = index_type{};
269  ML::forest::node_for_each<traversal_order>(
270  tl_model,
271  [&builder, &offsets, &node_index](
272  auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
273  if (node.is_leaf()) {
274  auto output = node.get_output();
275  builder.set_output_size(output.size());
276  if (output.size() > index_type{1}) {
277  builder.add_leaf_vector_node(
278  std::begin(output), std::end(output), node.get_treelite_id(), depth);
279  } else {
280  builder.add_node(
281  typename forest_model_t::io_type(output[0]), node.get_treelite_id(), depth, true);
282  }
283  } else {
284  if (node.is_categorical()) {
285  auto categories = node.get_categories();
286  builder.add_categorical_node(std::begin(categories),
287  std::end(categories),
288  node.get_treelite_id(),
289  depth,
290  node.default_distant(),
291  node.get_feature(),
292  offsets[node_index]);
293  } else {
294  builder.add_node(typename forest_model_t::threshold_type(node.threshold()),
295  node.get_treelite_id(),
296  depth,
297  false,
298  node.default_distant(),
299  false,
300  node.get_feature(),
301  offsets[node_index],
302  node.is_inclusive());
303  }
304  }
305  ++node_index;
306  });
307 
308  builder.set_average_factor(get_average_factor(tl_model));
309  builder.set_bias(get_bias(tl_model));
310  auto postproc_params = get_postproc_params(tl_model);
311  builder.set_element_postproc(postproc_params.element);
312  builder.set_row_postproc(postproc_params.row);
313  builder.set_postproc_constant(postproc_params.constant);
314 
315  result.template emplace<variant_index>(
316  builder.get_decision_forest(num_feature, num_class, mem_type, device, stream));
317  } else {
318  result = import_to_specific_variant<variant_index + 1>(target_variant_index,
319  tl_model,
320  num_class,
321  num_feature,
322  max_num_categories,
323  offsets,
324  align_bytes,
325  mem_type,
326  device,
327  stream);
328  }
329  }
330  return result;
331  }
332 
355  forest_model import(treelite::Model const& tl_model,
356  index_type align_bytes = index_type{},
357  std::optional<bool> use_double_precision = std::nullopt,
359  int device = 0,
361  {
362  // Handle degenerate trees (a single root node with no child)
363  if (auto processed_tl_model = detail::convert_degenerate_trees(tl_model); processed_tl_model) {
364  return import(
365  *processed_tl_model.get(), align_bytes, use_double_precision, dev_type, device, stream);
366  }
367 
368  ASSERT(tl_model.num_target == 1, "FIL does not support multi-target model");
369  // Check tree annotation (assignment)
370  if (tl_model.task_type == treelite::TaskType::kMultiClf) {
371  // Must be either vector leaf or grove-per-class
372  if (tl_model.leaf_vector_shape[1] > 1) { // vector-leaf
373  ASSERT(tl_model.leaf_vector_shape[1] == int(tl_model.num_class[0]),
374  "Vector leaf must be equal to num_class = %d",
375  tl_model.num_class[0]);
376  auto tree_count = num_trees(tl_model);
377  for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
378  ASSERT(tl_model.class_id[tree_id] == -1, "Tree %d has invalid class assignment", tree_id);
379  }
380  } else { // grove-per-class
381  auto tree_count = num_trees(tl_model);
382  for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
383  ASSERT(tl_model.class_id[tree_id] == int(tree_id % tl_model.num_class[0]),
384  "Tree %d has invalid class assignment",
385  tree_id);
386  }
387  }
388  }
389  // Check base_scores
390  for (std::int32_t class_id = 1; class_id < tl_model.num_class[0]; ++class_id) {
391  ASSERT(tl_model.base_scores[0] == tl_model.base_scores[class_id],
392  "base_scores must be identical for all classes");
393  }
394 
395  auto result = decision_forest_variant{};
396  auto num_feature = get_num_feature(tl_model);
397  auto max_num_categories = get_max_num_categories(tl_model);
398  auto num_categorical_nodes = get_num_categorical_nodes(tl_model);
399  auto num_leaf_vector_nodes = get_num_leaf_vector_nodes(tl_model);
400  auto use_double_thresholds = use_double_precision.value_or(uses_double_thresholds(tl_model));
401 
402  auto offsets = get_offsets(tl_model);
403  auto max_offset = *std::max_element(std::begin(offsets), std::end(offsets));
404 
405  auto variant_index = get_forest_variant_index(use_double_thresholds,
406  max_offset,
407  num_feature,
408  num_categorical_nodes,
409  max_num_categories,
410  num_leaf_vector_nodes,
411  layout);
412  auto num_class = get_num_class(tl_model);
413  return forest_model{import_to_specific_variant<index_type{}>(variant_index,
414  tl_model,
415  num_class,
416  num_feature,
417  max_num_categories,
418  offsets,
419  align_bytes,
420  dev_type,
421  device,
422  stream)};
423  }
424 };
425 
449 auto import_from_treelite_model(treelite::Model const& tl_model,
450  tree_layout layout = preferred_tree_layout,
451  index_type align_bytes = index_type{},
452  std::optional<bool> use_double_precision = std::nullopt,
454  int device = 0,
456 {
457  auto result = forest_model{};
458  switch (layout) {
459  case tree_layout::depth_first:
460  result = treelite_importer<tree_layout::depth_first>{}.import(
461  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
462  break;
463  case tree_layout::breadth_first:
464  result = treelite_importer<tree_layout::breadth_first>{}.import(
465  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
466  break;
467  case tree_layout::layered_children_together:
468  result = treelite_importer<tree_layout::layered_children_together>{}.import(
469  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
470  break;
471  }
472  return result;
473 }
474 
500  tree_layout layout = preferred_tree_layout,
501  index_type align_bytes = index_type{},
502  std::optional<bool> use_double_precision = std::nullopt,
504  int device = 0,
506 {
507  return import_from_treelite_model(*static_cast<treelite::Model*>(tl_handle),
508  layout,
509  align_bytes,
510  use_double_precision,
511  dev_type,
512  device,
513  stream);
514 }
515 
516 } // namespace fil
517 } // namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
std::unique_ptr< treelite::Model > convert_degenerate_trees(treelite::Model const &tl_model)
Definition: degenerate_trees.hpp:31
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:451
tree_layout
Definition: tree_layout.hpp:19
row_op
Definition: postproc_ops.hpp:21
element_op
Definition: postproc_ops.hpp:28
auto import_from_treelite_handle(TreeliteModelHandle tl_handle, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:499
auto import_from_treelite_model(treelite::Model const &tl_model, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:449
uint32_t index_type
Definition: index_type.hpp:20
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:431
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:187
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:176
Definition: dbscan.hpp:29
int cuda_stream
Definition: cuda_stream.hpp:25
device_type
Definition: device_type.hpp:18
Definition: treelite_importer.hpp:42
element_op element
Definition: treelite_importer.hpp:43
row_op row
Definition: treelite_importer.hpp:44
double constant
Definition: treelite_importer.hpp:45
Definition: forest_model.hpp:37
Definition: exceptions.hpp:35
Definition: node.hpp:92
HOST DEVICE constexpr auto is_categorical() const
Definition: node.hpp:165
HOST DEVICE constexpr auto is_leaf() const
Definition: node.hpp:155
Definition: treelite_importer.hpp:55
auto uses_integer_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:229
auto get_postproc_params(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:171
auto get_num_feature(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:113
auto get_max_num_categories(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:118
auto uses_double_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:217
auto get_bias(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:166
static constexpr auto const traversal_order
Definition: treelite_importer.hpp:56
auto get_num_leaf_vector_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:138
auto import_to_specific_variant(index_type target_variant_index, treelite::Model const &tl_model, index_type num_class, index_type num_feature, index_type max_num_categories, std::vector< index_type > const &offsets, index_type align_bytes=index_type{}, raft_proto::device_type mem_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:246
auto uses_double_thresholds(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:206
auto num_trees(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:92
auto get_num_categorical_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:128
auto get_tree_sizes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:100
auto get_offsets(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:76
auto get_node_count(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:69
auto get_num_class(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:108
auto get_average_factor(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:150
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23