degenerate_trees.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
8 
#include <treelite/tree.h>

#include <cstdint>
#include <memory>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
14 
15 namespace ML::fil::detail {
16 
17 // This function returns a modified copy of a given Treelite model if it contains
18 // at least one degenerate tree (a single root node with no child).
19 // If the model contains no degenerate tree, then the function returns nullptr.
20 std::unique_ptr<treelite::Model> convert_degenerate_trees(treelite::Model const& tl_model)
21 {
22  bool contains_degenerate =
23  ML::forest::tree_accumulate(tl_model, false, [](auto&& contains, auto&& tree) {
24  return contains || tree.IsLeaf(ML::forest::TREELITE_NODE_ID_T{});
25  });
26 
27  if (contains_degenerate) {
28  // Make a copy of the Treelite model, and then update the trees in-place
29  auto modified_model = treelite::ConcatenateModelObjects({&tl_model});
30  std::visit(
31  [](auto&& concrete_tl_model) {
32  using model_t = std::remove_const_t<std::remove_reference_t<decltype(concrete_tl_model)>>;
33  using tree_t =
34  treelite::Tree<typename model_t::threshold_type, typename model_t::leaf_output_type>;
35  auto modified_trees = std::vector<tree_t>{};
36  const auto root_id = ML::forest::TREELITE_NODE_ID_T{};
37  for (tree_t& tree : concrete_tl_model.trees) {
38  if (tree.IsLeaf(root_id)) {
39  const auto inst_cnt =
40  tree.HasDataCount(root_id) ? tree.DataCount(root_id) : std::uint64_t{};
41  auto new_tree = tree_t{};
42  new_tree.Init();
43  const auto root_id = new_tree.AllocNode();
44  const auto cleft_id = new_tree.AllocNode();
45  const auto cright_id = new_tree.AllocNode();
46  new_tree.SetChildren(root_id, cleft_id, cright_id);
47  new_tree.SetNumericalTest(
48  root_id, int{}, typename model_t::threshold_type{}, true, treelite::Operator::kLE);
49  if (tree.HasLeafVector(root_id)) {
50  const auto leaf_vector = tree.LeafVector(root_id);
51  new_tree.SetLeafVector(cleft_id, leaf_vector);
52  new_tree.SetLeafVector(cright_id, leaf_vector);
53  } else {
54  const auto leaf_value = tree.LeafValue(root_id);
55  new_tree.SetLeaf(cleft_id, leaf_value);
56  new_tree.SetLeaf(cright_id, leaf_value);
57  }
58  new_tree.SetDataCount(root_id, inst_cnt);
59  new_tree.SetDataCount(cleft_id, inst_cnt);
60  new_tree.SetDataCount(cright_id, std::uint64_t{});
61  modified_trees.push_back(std::move(new_tree));
62  } else {
63  modified_trees.push_back(std::move(tree));
64  }
65  }
66  concrete_tl_model.trees = std::move(modified_trees);
67  },
68  modified_model->variant_);
69  return modified_model;
70  } else {
71  return std::unique_ptr<treelite::Model>();
72  }
73 }
74 
75 } // namespace ML::fil::detail
Definition: decision_forest.hpp:352
std::unique_ptr< treelite::Model > convert_degenerate_trees(treelite::Model const &tl_model)
Definition: degenerate_trees.hpp:20
int TREELITE_NODE_ID_T
Definition: treelite.hpp:19
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:183