treelite_importer.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 #include <cuml/fil/constants.hpp>
22 #include <cuml/fil/exceptions.hpp>
25 #include <cuml/fil/tree_layout.hpp>
27 
28 #include <raft/core/error.hpp>
29 
30 #include <treelite/c_api.h>
31 #include <treelite/enum/task_type.h>
32 #include <treelite/enum/tree_node_type.h>
33 #include <treelite/enum/typeinfo.h>
34 #include <treelite/tree.h>
35 
36 #include <cmath>
37 #include <variant>
38 
39 namespace ML {
40 namespace fil {
41 
42 namespace detail {
43 
47  double constant = 1.0;
48 };
49 } // namespace detail
50 
56 template <tree_layout layout>
58  auto static constexpr const traversal_order = []() constexpr {
59  if constexpr (layout == tree_layout::depth_first) {
61  } else if constexpr (layout == tree_layout::breadth_first) {
63  } else if constexpr (layout == tree_layout::layered_children_together) {
65  } else {
66  static_assert(layout == tree_layout::depth_first,
67  "Layout not yet implemented in treelite importer for FIL");
68  }
69  }();
70 
71  auto get_node_count(treelite::Model const& tl_model)
72  {
74  tl_model, index_type{}, [](auto&& count, auto&& tree) { return count + tree.num_nodes; });
75  }
76 
77  /* Return vector of offsets between each node and its most distant child */
78  auto get_offsets(treelite::Model const& tl_model)
79  {
80  auto node_count = get_node_count(tl_model);
81  auto result = std::vector<index_type>(node_count);
82  auto parent_indexes = std::vector<index_type>{};
83  parent_indexes.reserve(node_count);
84  ML::forest::node_transform<traversal_order>(
85  tl_model,
86  std::back_inserter(parent_indexes),
87  [](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) { return parent_index; });
88  for (auto i = std::size_t{}; i < node_count; ++i) {
89  result[parent_indexes[i]] = i - parent_indexes[i];
90  }
91  return result;
92  }
93 
94  auto num_trees(treelite::Model const& tl_model)
95  {
96  auto result = index_type{};
97  std::visit([&result](auto&& concrete_tl_model) { result = concrete_tl_model.trees.size(); },
98  tl_model.variant_);
99  return result;
100  }
101 
102  auto get_tree_sizes(treelite::Model const& tl_model)
103  {
104  auto result = std::vector<index_type>{};
106  tl_model, std::back_inserter(result), [](auto&& tree) { return tree.num_nodes; });
107  return result;
108  }
109 
110  auto get_num_class(treelite::Model const& tl_model)
111  {
112  return static_cast<index_type>(tl_model.num_class[0]);
113  }
114 
115  auto get_num_feature(treelite::Model const& tl_model)
116  {
117  return static_cast<index_type>(tl_model.num_feature);
118  }
119 
120  auto get_max_num_categories(treelite::Model const& tl_model)
121  {
122  return ML::forest::node_accumulate<traversal_order>(
123  tl_model,
124  index_type{},
125  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
126  return std::max(cur_accum, static_cast<index_type>(node.max_num_categories()));
127  });
128  }
129 
130  auto get_num_categorical_nodes(treelite::Model const& tl_model)
131  {
132  return ML::forest::node_accumulate<traversal_order>(
133  tl_model,
134  index_type{},
135  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
136  return cur_accum + static_cast<index_type>(node.is_categorical());
137  });
138  }
139 
140  auto get_num_leaf_vector_nodes(treelite::Model const& tl_model)
141  {
142  return ML::forest::node_accumulate<traversal_order>(
143  tl_model,
144  index_type{},
145  [](auto&& cur_accum, auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
146  auto accum = cur_accum;
147  if (node.is_leaf() && node.get_output().size() > 1) { ++accum; }
148  return accum;
149  });
150  }
151 
152  auto get_average_factor(treelite::Model const& tl_model)
153  {
154  auto result = double{};
155  if (tl_model.average_tree_output) {
156  if (tl_model.task_type == treelite::TaskType::kMultiClf &&
157  tl_model.leaf_vector_shape[1] == 1) { // grove-per-class
158  result = num_trees(tl_model) / tl_model.num_class[0];
159  } else {
160  result = num_trees(tl_model);
161  }
162  } else {
163  result = 1.0;
164  }
165  return result;
166  }
167 
168  auto get_bias(treelite::Model const& tl_model)
169  {
170  return static_cast<double>(tl_model.base_scores[0]);
171  }
172 
173  auto get_postproc_params(treelite::Model const& tl_model)
174  {
175  auto result = detail::postproc_params_t{};
176  auto tl_pred_transform = tl_model.postprocessor;
177  if (tl_pred_transform == std::string{"identity"} ||
178  tl_pred_transform == std::string{"identity_multiclass"}) {
179  result.element = element_op::disable;
180  result.row = row_op::disable;
181  } else if (tl_pred_transform == std::string{"signed_square"}) {
182  result.element = element_op::signed_square;
183  } else if (tl_pred_transform == std::string{"hinge"}) {
184  result.element = element_op::hinge;
185  } else if (tl_pred_transform == std::string{"sigmoid"}) {
186  result.constant = tl_model.sigmoid_alpha;
187  result.element = element_op::sigmoid;
188  } else if (tl_pred_transform == std::string{"exponential"}) {
189  result.element = element_op::exponential;
190  } else if (tl_pred_transform == std::string{"exponential_standard_ratio"}) {
191  result.constant = -tl_model.ratio_c / std::log(2);
192  result.element = element_op::exponential;
193  } else if (tl_pred_transform == std::string{"logarithm_one_plus_exp"}) {
194  result.element = element_op::logarithm_one_plus_exp;
195  } else if (tl_pred_transform == std::string{"max_index"}) {
196  result.row = row_op::max_index;
197  } else if (tl_pred_transform == std::string{"softmax"}) {
198  result.row = row_op::softmax;
199  } else if (tl_pred_transform == std::string{"multiclass_ova"}) {
200  result.constant = tl_model.sigmoid_alpha;
201  result.element = element_op::sigmoid;
202  } else {
203  throw model_import_error{"Unrecognized Treelite pred_transform string"};
204  }
205  return result;
206  }
207 
208  auto uses_double_thresholds(treelite::Model const& tl_model)
209  {
210  auto result = false;
211  switch (tl_model.GetThresholdType()) {
212  case treelite::TypeInfo::kFloat64: result = true; break;
213  case treelite::TypeInfo::kFloat32: result = false; break;
214  default: throw model_import_error("Unrecognized Treelite threshold type");
215  }
216  return result;
217  }
218 
219  auto uses_double_outputs(treelite::Model const& tl_model)
220  {
221  auto result = false;
222  switch (tl_model.GetThresholdType()) {
223  case treelite::TypeInfo::kFloat64: result = true; break;
224  case treelite::TypeInfo::kFloat32: result = false; break;
225  case treelite::TypeInfo::kUInt32: result = false; break;
226  default: throw model_import_error("Unrecognized Treelite threshold type");
227  }
228  return result;
229  }
230 
231  auto uses_integer_outputs(treelite::Model const& tl_model)
232  {
233  auto result = false;
234  switch (tl_model.GetThresholdType()) {
235  case treelite::TypeInfo::kFloat64: result = false; break;
236  case treelite::TypeInfo::kFloat32: result = false; break;
237  case treelite::TypeInfo::kUInt32: result = true; break;
238  default: throw model_import_error("Unrecognized Treelite threshold type");
239  }
240  return result;
241  }
242 
247  template <index_type variant_index>
248  auto import_to_specific_variant(index_type target_variant_index,
249  treelite::Model const& tl_model,
250  index_type num_class,
251  index_type num_feature,
252  index_type max_num_categories,
253  std::vector<index_type> const& offsets,
254  index_type align_bytes = index_type{},
256  int device = 0,
258  {
259  auto result = decision_forest_variant{};
260  if constexpr (variant_index != std::variant_size_v<decision_forest_variant>) {
261  if (variant_index == target_variant_index) {
262  using forest_model_t = std::variant_alternative_t<variant_index, decision_forest_variant>;
264  // Cannot align whole trees with layered traversal order, since trees
265  // are mingled together
266  align_bytes = index_type{};
267  }
268  auto builder =
269  detail::decision_forest_builder<forest_model_t>(max_num_categories, align_bytes);
270  auto node_index = index_type{};
271  ML::forest::node_for_each<traversal_order>(
272  tl_model,
273  [&builder, &offsets, &node_index](
274  auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
275  if (node.is_leaf()) {
276  auto output = node.get_output();
277  builder.set_output_size(output.size());
278  if (output.size() > index_type{1}) {
279  builder.add_leaf_vector_node(
280  std::begin(output), std::end(output), node.get_treelite_id(), depth);
281  } else {
282  builder.add_node(
283  typename forest_model_t::io_type(output[0]), node.get_treelite_id(), depth, true);
284  }
285  } else {
286  if (node.is_categorical()) {
287  auto categories = node.get_categories();
288  builder.add_categorical_node(std::begin(categories),
289  std::end(categories),
290  node.get_treelite_id(),
291  depth,
292  node.default_distant(),
293  node.get_feature(),
294  offsets[node_index]);
295  } else {
296  builder.add_node(typename forest_model_t::threshold_type(node.threshold()),
297  node.get_treelite_id(),
298  depth,
299  false,
300  node.default_distant(),
301  false,
302  node.get_feature(),
303  offsets[node_index],
304  node.is_inclusive());
305  }
306  }
307  ++node_index;
308  });
309 
310  builder.set_average_factor(get_average_factor(tl_model));
311  builder.set_bias(get_bias(tl_model));
312  auto postproc_params = get_postproc_params(tl_model);
313  builder.set_element_postproc(postproc_params.element);
314  builder.set_row_postproc(postproc_params.row);
315  builder.set_postproc_constant(postproc_params.constant);
316 
317  result.template emplace<variant_index>(
318  builder.get_decision_forest(num_feature, num_class, mem_type, device, stream));
319  } else {
320  result = import_to_specific_variant<variant_index + 1>(target_variant_index,
321  tl_model,
322  num_class,
323  num_feature,
324  max_num_categories,
325  offsets,
326  align_bytes,
327  mem_type,
328  device,
329  stream);
330  }
331  }
332  return result;
333  }
334 
357  forest_model import(treelite::Model const& tl_model,
358  index_type align_bytes = index_type{},
359  std::optional<bool> use_double_precision = std::nullopt,
361  int device = 0,
363  {
364  // Handle degenerate trees (a single root node with no child)
365  if (auto processed_tl_model = detail::convert_degenerate_trees(tl_model); processed_tl_model) {
366  return import(
367  *processed_tl_model.get(), align_bytes, use_double_precision, dev_type, device, stream);
368  }
369 
370  ASSERT(tl_model.num_target == 1, "FIL does not support multi-target model");
371  // Check tree annotation (assignment)
372  if (tl_model.task_type == treelite::TaskType::kMultiClf) {
373  // Must be either vector leaf or grove-per-class
374  if (tl_model.leaf_vector_shape[1] > 1) { // vector-leaf
375  ASSERT(tl_model.leaf_vector_shape[1] == int(tl_model.num_class[0]),
376  "Vector leaf must be equal to num_class = %d",
377  tl_model.num_class[0]);
378  auto tree_count = num_trees(tl_model);
379  for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
380  ASSERT(tl_model.class_id[tree_id] == -1, "Tree %d has invalid class assignment", tree_id);
381  }
382  } else { // grove-per-class
383  auto tree_count = num_trees(tl_model);
384  for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
385  ASSERT(tl_model.class_id[tree_id] == int(tree_id % tl_model.num_class[0]),
386  "Tree %d has invalid class assignment",
387  tree_id);
388  }
389  }
390  }
391  // Check base_scores
392  for (std::int32_t class_id = 1; class_id < tl_model.num_class[0]; ++class_id) {
393  ASSERT(tl_model.base_scores[0] == tl_model.base_scores[class_id],
394  "base_scores must be identical for all classes");
395  }
396 
397  auto result = decision_forest_variant{};
398  auto num_feature = get_num_feature(tl_model);
399  auto max_num_categories = get_max_num_categories(tl_model);
400  auto num_categorical_nodes = get_num_categorical_nodes(tl_model);
401  auto num_leaf_vector_nodes = get_num_leaf_vector_nodes(tl_model);
402  auto use_double_thresholds = use_double_precision.value_or(uses_double_thresholds(tl_model));
403 
404  auto offsets = get_offsets(tl_model);
405  auto max_offset = *std::max_element(std::begin(offsets), std::end(offsets));
406 
407  auto variant_index = get_forest_variant_index(use_double_thresholds,
408  max_offset,
409  num_feature,
410  num_categorical_nodes,
411  max_num_categories,
412  num_leaf_vector_nodes,
413  layout);
414  auto num_class = get_num_class(tl_model);
415  return forest_model{import_to_specific_variant<index_type{}>(variant_index,
416  tl_model,
417  num_class,
418  num_feature,
419  max_num_categories,
420  offsets,
421  align_bytes,
422  dev_type,
423  device,
424  stream)};
425  }
426 };
427 
451 auto import_from_treelite_model(treelite::Model const& tl_model,
452  tree_layout layout = preferred_tree_layout,
453  index_type align_bytes = index_type{},
454  std::optional<bool> use_double_precision = std::nullopt,
456  int device = 0,
458 {
459  auto result = forest_model{};
460  switch (layout) {
461  case tree_layout::depth_first:
462  result = treelite_importer<tree_layout::depth_first>{}.import(
463  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
464  break;
465  case tree_layout::breadth_first:
466  result = treelite_importer<tree_layout::breadth_first>{}.import(
467  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
468  break;
469  case tree_layout::layered_children_together:
470  result = treelite_importer<tree_layout::layered_children_together>{}.import(
471  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
472  break;
473  }
474  return result;
475 }
476 
502  tree_layout layout = preferred_tree_layout,
503  index_type align_bytes = index_type{},
504  std::optional<bool> use_double_precision = std::nullopt,
506  int device = 0,
508 {
509  return import_from_treelite_model(*static_cast<treelite::Model*>(tl_handle),
510  layout,
511  align_bytes,
512  use_double_precision,
513  dev_type,
514  device,
515  stream);
516 }
517 
518 } // namespace fil
519 } // namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
std::unique_ptr< treelite::Model > convert_degenerate_trees(treelite::Model const &tl_model)
Definition: degenerate_trees.hpp:31
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:451
tree_layout
Definition: tree_layout.hpp:19
row_op
Definition: postproc_ops.hpp:21
element_op
Definition: postproc_ops.hpp:28
auto import_from_treelite_handle(TreeliteModelHandle tl_handle, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:501
auto import_from_treelite_model(treelite::Model const &tl_model, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:451
uint32_t index_type
Definition: index_type.hpp:20
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:431
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:187
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:176
Definition: dbscan.hpp:29
int cuda_stream
Definition: cuda_stream.hpp:25
device_type
Definition: device_type.hpp:18
Definition: treelite_importer.hpp:44
element_op element
Definition: treelite_importer.hpp:45
row_op row
Definition: treelite_importer.hpp:46
double constant
Definition: treelite_importer.hpp:47
Definition: forest_model.hpp:40
Definition: exceptions.hpp:35
Definition: node.hpp:92
HOST DEVICE constexpr auto is_categorical() const
Definition: node.hpp:165
HOST DEVICE constexpr auto is_leaf() const
Definition: node.hpp:155
Definition: treelite_importer.hpp:57
auto uses_integer_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:231
auto get_postproc_params(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:173
auto get_num_feature(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:115
auto get_max_num_categories(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:120
auto uses_double_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:219
auto get_bias(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:168
static constexpr auto const traversal_order
Definition: treelite_importer.hpp:58
auto get_num_leaf_vector_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:140
auto import_to_specific_variant(index_type target_variant_index, treelite::Model const &tl_model, index_type num_class, index_type num_feature, index_type max_num_categories, std::vector< index_type > const &offsets, index_type align_bytes=index_type{}, raft_proto::device_type mem_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:248
auto uses_double_thresholds(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:208
auto num_trees(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:94
auto get_num_categorical_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:130
auto get_tree_sizes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:102
auto get_offsets(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:78
auto get_node_count(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:71
auto get_num_class(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:110
auto get_average_factor(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:152
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23