treelite_importer.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
25 
26 #include <treelite/c_api.h>
27 #include <treelite/enum/task_type.h>
28 #include <treelite/enum/tree_node_type.h>
29 #include <treelite/enum/typeinfo.h>
30 #include <treelite/tree.h>
31 
32 #include <cmath>
33 #include <cstddef>
34 #include <queue>
35 #include <stack>
36 #include <variant>
37 
38 namespace ML {
39 namespace experimental {
40 namespace fil {
41 
42 namespace detail {
45 template <tree_layout layout, typename T>
48  std::conditional_t<layout == tree_layout::depth_first, std::stack<T>, std::queue<T>>;
49  void add(T const& val) { data_.push(val); }
50  void add(T const& hot, T const& distant)
51  {
52  if constexpr (layout == tree_layout::depth_first) {
53  data_.push(distant);
54  data_.push(hot);
55  } else {
56  data_.push(hot);
57  data_.push(distant);
58  }
59  }
60  auto next()
61  {
62  if constexpr (std::is_same_v<backing_container_t, std::stack<T>>) {
63  auto result = data_.top();
64  data_.pop();
65  return result;
66  } else {
67  auto result = data_.front();
68  data_.pop();
69  return result;
70  }
71  }
72  auto peek()
73  {
74  if constexpr (std::is_same_v<backing_container_t, std::stack<T>>) {
75  return data_.top();
76  } else {
77  return data_.front();
78  }
79  }
80  [[nodiscard]] auto empty() { return data_.empty(); }
81  auto size() { return data_.size(); }
82 
83  private:
84  backing_container_t data_;
85 };
86 
90  double constant = 1.0;
91 };
92 } // namespace detail
93 
99 template <tree_layout layout>
101  template <typename tl_threshold_t, typename tl_output_t>
102  struct treelite_node {
103  treelite::Tree<tl_threshold_t, tl_output_t> const& tree;
104  int node_id;
107 
108  auto is_leaf() { return tree.IsLeaf(node_id); }
109 
110  auto get_output()
111  {
112  auto result = std::vector<tl_output_t>{};
113  if (tree.HasLeafVector(node_id)) {
114  result = tree.LeafVector(node_id);
115  } else {
116  result.push_back(tree.LeafValue(node_id));
117  }
118  return result;
119  }
120 
121  auto get_categories() { return tree.CategoryList(node_id); }
122 
123  auto get_feature() { return tree.SplitIndex(node_id); }
124 
126  {
127  return tree.NodeType(node_id) == treelite::TreeNodeType::kCategoricalTestNode;
128  }
129 
131  {
132  auto result = false;
133  auto default_child = tree.DefaultChild(node_id);
134  if (is_categorical()) {
135  if (tree.CategoryListRightChild(node_id)) {
136  result = (default_child == tree.RightChild(node_id));
137  } else {
138  result = (default_child == tree.LeftChild(node_id));
139  }
140  } else {
141  auto tl_operator = tree.ComparisonOp(node_id);
142  if (tl_operator == treelite::Operator::kLT || tl_operator == treelite::Operator::kLE) {
143  result = (default_child == tree.LeftChild(node_id));
144  } else {
145  result = (default_child == tree.RightChild(node_id));
146  }
147  }
148  return result;
149  }
150 
151  auto threshold() { return tree.Threshold(node_id); }
152 
153  auto categories()
154  {
155  auto result = decltype(tree.CategoryList(node_id)){};
156  if (is_categorical()) { result = tree.CategoryList(node_id); }
157  return result;
158  }
159 
161  {
162  auto tl_operator = tree.ComparisonOp(node_id);
163  return tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kLE;
164  }
165  };
166 
167  template <typename tl_threshold_t, typename tl_output_t, typename lambda_t>
168  void node_for_each(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree, lambda_t&& lambda)
169  {
170  using node_index_t = decltype(tl_tree.LeftChild(0));
172  to_be_visited.add(node_index_t{});
173 
174  auto parent_indices = detail::traversal_container<layout, index_type>{};
175  auto cur_index = index_type{};
176  parent_indices.add(cur_index);
177 
178  while (!to_be_visited.empty()) {
179  auto node_id = to_be_visited.next();
180  auto remaining_size = to_be_visited.size();
181 
183  tl_tree, node_id, parent_indices.next(), cur_index};
184  lambda(tl_node, node_id);
185 
186  if (!tl_tree.IsLeaf(node_id)) {
187  auto tl_left_id = tl_tree.LeftChild(node_id);
188  auto tl_right_id = tl_tree.RightChild(node_id);
189  auto tl_operator = tl_tree.ComparisonOp(node_id);
190  if (!tl_node.is_categorical()) {
191  if (tl_operator == treelite::Operator::kLT || tl_operator == treelite::Operator::kLE) {
192  to_be_visited.add(tl_right_id, tl_left_id);
193  } else if (tl_operator == treelite::Operator::kGT ||
194  tl_operator == treelite::Operator::kGE) {
195  to_be_visited.add(tl_left_id, tl_right_id);
196  } else {
197  throw model_import_error("Unrecognized Treelite operator");
198  }
199  } else {
200  if (tl_tree.CategoryListRightChild(node_id)) {
201  to_be_visited.add(tl_left_id, tl_right_id);
202  } else {
203  to_be_visited.add(tl_right_id, tl_left_id);
204  }
205  }
206  parent_indices.add(cur_index, cur_index);
207  }
208  ++cur_index;
209  }
210  }
211 
212  template <typename tl_threshold_t, typename tl_output_t, typename iter_t, typename lambda_t>
213  void node_transform(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree,
214  iter_t output_iter,
215  lambda_t&& lambda)
216  {
217  node_for_each(tl_tree, [&output_iter, &lambda](auto&& tl_node, int tl_node_id) {
218  *output_iter = lambda(tl_node);
219  ++output_iter;
220  });
221  }
222 
223  template <typename tl_threshold_t, typename tl_output_t, typename T, typename lambda_t>
224  auto node_accumulate(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree,
225  T init,
226  lambda_t&& lambda)
227  {
228  auto result = init;
229  node_for_each(tl_tree, [&result, &lambda](auto&& tl_node, int tl_node_id) {
230  result = lambda(result, tl_node);
231  });
232  return result;
233  }
234 
235  template <typename tl_threshold_t, typename tl_output_t>
236  auto get_nodes(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree)
237  {
238  auto result = std::vector<treelite_node<tl_threshold_t, tl_output_t>>{};
239  result.reserve(tl_tree.num_nodes);
240  node_transform(tl_tree, std::back_inserter(result), [](auto&& node) { return node; });
241  return result;
242  }
243 
244  template <typename tl_threshold_t, typename tl_output_t>
245  auto get_offsets(treelite::Tree<tl_threshold_t, tl_output_t> const& tl_tree)
246  {
247  auto result = std::vector<index_type>(tl_tree.num_nodes);
248  auto nodes = get_nodes(tl_tree);
249  for (auto i = index_type{}; i < nodes.size(); ++i) {
250  // Current index should always be greater than or equal to parent index.
251  // Later children will overwrite values set by earlier children, ensuring
252  // that most distant offset is used.
253  result[nodes[i].parent_index] = index_type{i - nodes[i].parent_index};
254  }
255 
256  return result;
257  }
258 
259  template <typename lambda_t>
260  void tree_for_each(treelite::Model const& tl_model, lambda_t&& lambda)
261  {
262  std::visit(
263  [&lambda](auto&& concrete_tl_model) {
264  std::for_each(
265  std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), lambda);
266  },
267  tl_model.variant_);
268  }
269 
270  template <typename iter_t, typename lambda_t>
271  void tree_transform(treelite::Model const& tl_model, iter_t output_iter, lambda_t&& lambda)
272  {
273  std::visit(
274  [&output_iter, &lambda](auto&& concrete_tl_model) {
275  std::transform(std::begin(concrete_tl_model.trees),
276  std::end(concrete_tl_model.trees),
277  output_iter,
278  lambda);
279  },
280  tl_model.variant_);
281  }
282 
283  template <typename T, typename lambda_t>
284  auto tree_accumulate(treelite::Model const& tl_model, T init, lambda_t&& lambda)
285  {
286  auto result = init;
287  tree_for_each(tl_model, [&result, &lambda](auto&& tree) { result = lambda(result, tree); });
288  return result;
289  }
290 
291  auto num_trees(treelite::Model const& tl_model)
292  {
293  auto result = index_type{};
294  std::visit([&result](auto&& concrete_tl_model) { result = concrete_tl_model.trees.size(); },
295  tl_model.variant_);
296  return result;
297  }
298 
299  auto get_offsets(treelite::Model const& tl_model)
300  {
301  auto result = std::vector<std::vector<index_type>>{};
302  result.reserve(num_trees(tl_model));
304  tl_model, std::back_inserter(result), [this](auto&& tree) { return get_offsets(tree); });
305  return result;
306  }
307 
308  auto get_tree_sizes(treelite::Model const& tl_model)
309  {
310  auto result = std::vector<index_type>{};
312  tl_model, std::back_inserter(result), [](auto&& tree) { return tree.num_nodes; });
313  return result;
314  }
315 
316  auto get_num_class(treelite::Model const& tl_model)
317  {
318  return static_cast<index_type>(tl_model.num_class[0]);
319  }
320 
321  auto get_num_feature(treelite::Model const& tl_model)
322  {
323  return static_cast<index_type>(tl_model.num_feature);
324  }
325 
326  auto get_max_num_categories(treelite::Model const& tl_model)
327  {
328  return tree_accumulate(tl_model, index_type{}, [this](auto&& accum, auto&& tree) {
329  return node_accumulate(tree, accum, [](auto&& cur_accum, auto&& tl_node) {
330  auto result = cur_accum;
331  for (auto&& cat : tl_node.categories()) {
332  result = (cat + 1 > result) ? cat + 1 : result;
333  }
334  return result;
335  });
336  });
337  }
338 
339  auto get_num_categorical_nodes(treelite::Model const& tl_model)
340  {
341  return tree_accumulate(tl_model, index_type{}, [this](auto&& accum, auto&& tree) {
342  return node_accumulate(tree, accum, [](auto&& cur_accum, auto&& tl_node) {
343  return cur_accum + tl_node.is_categorical();
344  });
345  });
346  }
347 
348  auto get_num_leaf_vector_nodes(treelite::Model const& tl_model)
349  {
350  return tree_accumulate(tl_model, index_type{}, [this](auto&& accum, auto&& tree) {
351  return node_accumulate(tree, accum, [](auto&& cur_accum, auto&& tl_node) {
352  return cur_accum + (tl_node.is_leaf() && tl_node.get_output().size() > 1);
353  });
354  });
355  }
356 
357  auto get_average_factor(treelite::Model const& tl_model)
358  {
359  auto result = double{};
360  if (tl_model.average_tree_output) {
361  if (tl_model.task_type == treelite::TaskType::kMultiClf &&
362  tl_model.leaf_vector_shape[1] == 1) { // grove-per-class
363  result = num_trees(tl_model) / tl_model.num_class[0];
364  } else {
365  result = num_trees(tl_model);
366  }
367  } else {
368  result = 1.0;
369  }
370  return result;
371  }
372 
373  auto get_bias(treelite::Model const& tl_model)
374  {
375  return static_cast<double>(tl_model.base_scores[0]);
376  }
377 
378  auto get_postproc_params(treelite::Model const& tl_model)
379  {
380  auto result = detail::postproc_params_t{};
381  auto tl_pred_transform = tl_model.postprocessor;
382  if (tl_pred_transform == std::string{"identity"} ||
383  tl_pred_transform == std::string{"identity_multiclass"}) {
384  result.element = element_op::disable;
385  result.row = row_op::disable;
386  } else if (tl_pred_transform == std::string{"signed_square"}) {
387  result.element = element_op::signed_square;
388  } else if (tl_pred_transform == std::string{"hinge"}) {
389  result.element = element_op::hinge;
390  } else if (tl_pred_transform == std::string{"sigmoid"}) {
391  result.constant = tl_model.sigmoid_alpha;
392  result.element = element_op::sigmoid;
393  } else if (tl_pred_transform == std::string{"exponential"}) {
394  result.element = element_op::exponential;
395  } else if (tl_pred_transform == std::string{"exponential_standard_ratio"}) {
396  result.constant = -tl_model.ratio_c / std::log(2);
397  result.element = element_op::exponential;
398  } else if (tl_pred_transform == std::string{"logarithm_one_plus_exp"}) {
399  result.element = element_op::logarithm_one_plus_exp;
400  } else if (tl_pred_transform == std::string{"max_index"}) {
401  result.row = row_op::max_index;
402  } else if (tl_pred_transform == std::string{"softmax"}) {
403  result.row = row_op::softmax;
404  } else if (tl_pred_transform == std::string{"multiclass_ova"}) {
405  result.constant = tl_model.sigmoid_alpha;
406  result.element = element_op::sigmoid;
407  } else {
408  throw model_import_error{"Unrecognized Treelite pred_transform string"};
409  }
410  return result;
411  }
412 
413  auto uses_double_thresholds(treelite::Model const& tl_model)
414  {
415  auto result = false;
416  switch (tl_model.GetThresholdType()) {
417  case treelite::TypeInfo::kFloat64: result = true; break;
418  case treelite::TypeInfo::kFloat32: result = false; break;
419  default: throw model_import_error("Unrecognized Treelite threshold type");
420  }
421  return result;
422  }
423 
424  auto uses_double_outputs(treelite::Model const& tl_model)
425  {
426  auto result = false;
427  switch (tl_model.GetThresholdType()) {
428  case treelite::TypeInfo::kFloat64: result = true; break;
429  case treelite::TypeInfo::kFloat32: result = false; break;
430  case treelite::TypeInfo::kUInt32: result = false; break;
431  default: throw model_import_error("Unrecognized Treelite threshold type");
432  }
433  return result;
434  }
435 
436  auto uses_integer_outputs(treelite::Model const& tl_model)
437  {
438  auto result = false;
439  switch (tl_model.GetThresholdType()) {
440  case treelite::TypeInfo::kFloat64: result = false; break;
441  case treelite::TypeInfo::kFloat32: result = false; break;
442  case treelite::TypeInfo::kUInt32: result = true; break;
443  default: throw model_import_error("Unrecognized Treelite threshold type");
444  }
445  return result;
446  }
447 
452  template <index_type variant_index>
453  auto import_to_specific_variant(index_type target_variant_index,
454  treelite::Model const& tl_model,
455  index_type num_class,
456  index_type num_feature,
457  index_type max_num_categories,
458  std::vector<std::vector<index_type>> const& offsets,
459  index_type align_bytes = index_type{},
461  int device = 0,
463  {
464  auto result = decision_forest_variant{};
465  if constexpr (variant_index != std::variant_size_v<decision_forest_variant>) {
466  if (variant_index == target_variant_index) {
467  using forest_model_t = std::variant_alternative_t<variant_index, decision_forest_variant>;
468  auto builder =
469  detail::decision_forest_builder<forest_model_t>(max_num_categories, align_bytes);
470  auto tree_count = num_trees(tl_model);
471  auto tree_index = index_type{};
472  tree_for_each(tl_model, [this, &builder, &tree_index, &offsets](auto&& tree) {
473  builder.start_new_tree();
474  auto node_index = index_type{};
476  tree, [&builder, &tree_index, &node_index, &offsets](auto&& node, int tl_node_id) {
477  if (node.is_leaf()) {
478  auto output = node.get_output();
479  builder.set_output_size(output.size());
480  if (output.size() > index_type{1}) {
481  builder.add_leaf_vector_node(std::begin(output), std::end(output), tl_node_id);
482  } else {
483  builder.add_node(typename forest_model_t::io_type(output[0]), tl_node_id, true);
484  }
485  } else {
486  if (node.is_categorical()) {
487  auto categories = node.get_categories();
488  builder.add_categorical_node(std::begin(categories),
489  std::end(categories),
490  tl_node_id,
491  node.default_distant(),
492  node.get_feature(),
493  offsets[tree_index][node_index]);
494  } else {
495  builder.add_node(typename forest_model_t::threshold_type(node.threshold()),
496  tl_node_id,
497  false,
498  node.default_distant(),
499  false,
500  node.get_feature(),
501  offsets[tree_index][node_index],
502  node.is_inclusive());
503  }
504  }
505  ++node_index;
506  });
507  ++tree_index;
508  });
509 
510  builder.set_average_factor(get_average_factor(tl_model));
511  builder.set_bias(get_bias(tl_model));
512  auto postproc_params = get_postproc_params(tl_model);
513  builder.set_element_postproc(postproc_params.element);
514  builder.set_row_postproc(postproc_params.row);
515  builder.set_postproc_constant(postproc_params.constant);
516 
517  result.template emplace<variant_index>(
518  builder.get_decision_forest(num_feature, num_class, mem_type, device, stream));
519  } else {
520  result = import_to_specific_variant<variant_index + 1>(target_variant_index,
521  tl_model,
522  num_class,
523  num_feature,
524  max_num_categories,
525  offsets,
526  align_bytes,
527  mem_type,
528  device,
529  stream);
530  }
531  }
532  return result;
533  }
534 
557  auto import(treelite::Model const& tl_model,
558  index_type align_bytes = index_type{},
559  std::optional<bool> use_double_precision = std::nullopt,
561  int device = 0,
563  {
564  ASSERT(tl_model.num_target == 1, "FIL does not support multi-target model");
565  // Check tree annotation (assignment)
566  if (tl_model.task_type == treelite::TaskType::kMultiClf) {
567  // Must be either vector leaf or grove-per-class
568  if (tl_model.leaf_vector_shape[1] > 1) { // vector-leaf
569  ASSERT(tl_model.leaf_vector_shape[1] == tl_model.num_class[0],
570  "Vector leaf must be equal to num_class = %d",
571  tl_model.num_class[0]);
572  auto tree_count = num_trees(tl_model);
573  for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
574  ASSERT(tl_model.class_id[tree_id] == -1, "Tree %d has invalid class assignment", tree_id);
575  }
576  } else { // grove-per-class
577  auto tree_count = num_trees(tl_model);
578  for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
579  ASSERT(tl_model.class_id[tree_id] == tree_id % tl_model.num_class[0],
580  "Tree %d has invalid class assignment",
581  tree_id);
582  }
583  }
584  }
585  // Check base_scores
586  for (std::int32_t class_id = 1; class_id < tl_model.num_class[0]; ++class_id) {
587  ASSERT(tl_model.base_scores[0] == tl_model.base_scores[class_id],
588  "base_scores must be identical for all classes");
589  }
590 
591  auto result = decision_forest_variant{};
592  auto num_feature = get_num_feature(tl_model);
593  auto max_num_categories = get_max_num_categories(tl_model);
594  auto num_categorical_nodes = get_num_categorical_nodes(tl_model);
595  auto num_leaf_vector_nodes = get_num_leaf_vector_nodes(tl_model);
596  auto use_double_thresholds = use_double_precision.value_or(uses_double_thresholds(tl_model));
597 
598  auto offsets = get_offsets(tl_model);
599  auto max_offset = std::accumulate(
600  std::begin(offsets),
601  std::end(offsets),
602  index_type{},
603  [&offsets](auto&& cur_max, auto&& tree_offsets) {
604  return std::max(cur_max,
605  *std::max_element(std::begin(tree_offsets), std::end(tree_offsets)));
606  });
607  auto tree_sizes = std::vector<index_type>{};
608  std::transform(std::begin(offsets),
609  std::end(offsets),
610  std::back_inserter(tree_sizes),
611  [](auto&& tree_offsets) { return tree_offsets.size(); });
612 
613  auto variant_index = get_forest_variant_index(use_double_thresholds,
614  max_offset,
615  num_feature,
616  num_categorical_nodes,
617  max_num_categories,
618  num_leaf_vector_nodes,
619  layout);
620  auto num_class = get_num_class(tl_model);
621  return forest_model{import_to_specific_variant<index_type{}>(variant_index,
622  tl_model,
623  num_class,
624  num_feature,
625  max_num_categories,
626  offsets,
627  align_bytes,
628  dev_type,
629  device,
630  stream)};
631  }
632 };
633 
657 auto import_from_treelite_model(treelite::Model const& tl_model,
658  tree_layout layout = preferred_tree_layout,
659  index_type align_bytes = index_type{},
660  std::optional<bool> use_double_precision = std::nullopt,
662  int device = 0,
664 {
665  auto result = forest_model{};
666  switch (layout) {
667  case tree_layout::depth_first:
668  result = treelite_importer<tree_layout::depth_first>{}.import(
669  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
670  break;
671  case tree_layout::breadth_first:
672  result = treelite_importer<tree_layout::breadth_first>{}.import(
673  tl_model, align_bytes, use_double_precision, dev_type, device, stream);
674  break;
675  }
676  return result;
677 }
678 
704  tree_layout layout = preferred_tree_layout,
705  index_type align_bytes = index_type{},
706  std::optional<bool> use_double_precision = std::nullopt,
708  int device = 0,
710 {
711  return import_from_treelite_model(*static_cast<treelite::Model*>(tl_handle),
712  layout,
713  align_bytes,
714  use_double_precision,
715  dev_type,
716  device,
717  stream);
718 }
719 
720 } // namespace fil
721 } // namespace experimental
722 } // namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
tree_layout
Definition: tree_layout.hpp:20
element_op
Definition: postproc_ops.hpp:29
uint32_t index_type
Definition: index_type.hpp:21
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:436
auto import_from_treelite_handle(TreeliteModelHandle tl_handle, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:703
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:416
row_op
Definition: postproc_ops.hpp:22
auto import_from_treelite_model(treelite::Model const &tl_model, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:657
void transform(const raft::handle_t &handle, const KMeansParams &params, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
Definition: dbscan.hpp:30
int cuda_stream
Definition: cuda_stream.hpp:25
device_type
Definition: device_type.hpp:18
Definition: treelite_importer.hpp:87
element_op element
Definition: treelite_importer.hpp:88
double constant
Definition: treelite_importer.hpp:90
row_op row
Definition: treelite_importer.hpp:89
Definition: treelite_importer.hpp:46
auto empty()
Definition: treelite_importer.hpp:80
auto next()
Definition: treelite_importer.hpp:60
auto peek()
Definition: treelite_importer.hpp:72
void add(T const &val)
Definition: treelite_importer.hpp:49
auto size()
Definition: treelite_importer.hpp:81
void add(T const &hot, T const &distant)
Definition: treelite_importer.hpp:50
std::conditional_t< layout==tree_layout::depth_first, std::stack< T >, std::queue< T > > backing_container_t
Definition: treelite_importer.hpp:48
Definition: exceptions.hpp:36
Definition: node.hpp:93
Definition: treelite_importer.hpp:102
auto get_feature()
Definition: treelite_importer.hpp:123
auto is_categorical()
Definition: treelite_importer.hpp:125
auto get_categories()
Definition: treelite_importer.hpp:121
auto threshold()
Definition: treelite_importer.hpp:151
index_type parent_index
Definition: treelite_importer.hpp:105
auto is_leaf()
Definition: treelite_importer.hpp:108
auto get_output()
Definition: treelite_importer.hpp:110
auto default_distant()
Definition: treelite_importer.hpp:130
index_type own_index
Definition: treelite_importer.hpp:106
auto is_inclusive()
Definition: treelite_importer.hpp:160
int node_id
Definition: treelite_importer.hpp:104
treelite::Tree< tl_threshold_t, tl_output_t > const & tree
Definition: treelite_importer.hpp:103
auto categories()
Definition: treelite_importer.hpp:153
Definition: treelite_importer.hpp:100
auto get_nodes(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree)
Definition: treelite_importer.hpp:236
void tree_transform(treelite::Model const &tl_model, iter_t output_iter, lambda_t &&lambda)
Definition: treelite_importer.hpp:271
auto node_accumulate(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, T init, lambda_t &&lambda)
Definition: treelite_importer.hpp:224
void node_transform(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, iter_t output_iter, lambda_t &&lambda)
Definition: treelite_importer.hpp:213
auto uses_double_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:424
auto get_num_feature(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:321
auto get_tree_sizes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:308
auto import_to_specific_variant(index_type target_variant_index, treelite::Model const &tl_model, index_type num_class, index_type num_feature, index_type max_num_categories, std::vector< std::vector< index_type >> const &offsets, index_type align_bytes=index_type{}, raft_proto::device_type mem_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:453
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite_importer.hpp:284
auto get_num_leaf_vector_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:348
void node_for_each(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, lambda_t &&lambda)
Definition: treelite_importer.hpp:168
auto get_num_class(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:316
auto get_postproc_params(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:378
auto get_offsets(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree)
Definition: treelite_importer.hpp:245
auto get_num_categorical_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:339
auto get_bias(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:373
auto uses_double_thresholds(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:413
auto uses_integer_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:436
auto get_average_factor(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:357
auto get_offsets(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:299
void tree_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite_importer.hpp:260
auto get_max_num_categories(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:326
auto num_trees(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:291
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23