degenerate_trees.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
19 
20 #include <treelite/tree.h>
21 
22 #include <cstdint>
23 #include <memory>
24 #include <type_traits>
25 
27 
28 // This function returns a modified copy of a given Treelite model if it contains
29 // at least one degenerate tree (a single root node with no child).
30 // If the model contains no degenerate tree, then the function returns nullptr.
31 std::unique_ptr<treelite::Model> convert_degenerate_trees(treelite::Model const& tl_model)
32 {
33  bool contains_degenerate =
34  ML::experimental::forest::tree_accumulate(tl_model, false, [](auto&& contains, auto&& tree) {
35  return contains || tree.IsLeaf(ML::experimental::forest::TREELITE_NODE_ID_T{});
36  });
37 
38  if (contains_degenerate) {
39  // Make a copy of the Treelite model, and then update the trees in-place
40  auto modified_model = treelite::ConcatenateModelObjects({&tl_model});
41  std::visit(
42  [](auto&& concrete_tl_model) {
43  using model_t = std::remove_const_t<std::remove_reference_t<decltype(concrete_tl_model)>>;
44  using tree_t =
45  treelite::Tree<typename model_t::threshold_type, typename model_t::leaf_output_type>;
46  auto modified_trees = std::vector<tree_t>{};
47  const auto root_id = experimental::forest::TREELITE_NODE_ID_T{};
48  for (tree_t& tree : concrete_tl_model.trees) {
49  if (tree.IsLeaf(root_id)) {
50  const auto inst_cnt =
51  tree.HasDataCount(root_id) ? tree.DataCount(root_id) : std::uint64_t{};
52  auto new_tree = tree_t{};
53  new_tree.Init();
54  const auto root_id = new_tree.AllocNode();
55  const auto cleft_id = new_tree.AllocNode();
56  const auto cright_id = new_tree.AllocNode();
57  new_tree.SetChildren(root_id, cleft_id, cright_id);
58  new_tree.SetNumericalTest(
59  root_id, int{}, typename model_t::threshold_type{}, true, treelite::Operator::kLE);
60  if (tree.HasLeafVector(root_id)) {
61  const auto leaf_vector = tree.LeafVector(root_id);
62  new_tree.SetLeafVector(cleft_id, leaf_vector);
63  new_tree.SetLeafVector(cright_id, leaf_vector);
64  } else {
65  const auto leaf_value = tree.LeafValue(root_id);
66  new_tree.SetLeaf(cleft_id, leaf_value);
67  new_tree.SetLeaf(cright_id, leaf_value);
68  }
69  new_tree.SetDataCount(root_id, inst_cnt);
70  new_tree.SetDataCount(cleft_id, inst_cnt);
71  new_tree.SetDataCount(cright_id, std::uint64_t{});
72  modified_trees.push_back(std::move(new_tree));
73  } else {
74  modified_trees.push_back(std::move(tree));
75  }
76  }
77  concrete_tl_model.trees = std::move(modified_trees);
78  },
79  modified_model->variant_);
80  return modified_model;
81  } else {
82  return std::unique_ptr<treelite::Model>();
83  }
84 }
85 
86 } // namespace ML::experimental::fil::detail
Definition: decision_forest.hpp:359
std::unique_ptr< treelite::Model > convert_degenerate_trees(treelite::Model const &tl_model)
Definition: degenerate_trees.hpp:31
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:188
int TREELITE_NODE_ID_T
Definition: treelite.hpp:31