9 #include <treelite/tree.h>
13 #include <type_traits>
22 bool contains_degenerate =
27 if (contains_degenerate) {
29 auto modified_model = treelite::ConcatenateModelObjects({&tl_model});
31 [](
auto&& concrete_tl_model) {
32 using model_t = std::remove_const_t<std::remove_reference_t<decltype(concrete_tl_model)>>;
34 treelite::Tree<typename model_t::threshold_type, typename model_t::leaf_output_type>;
35 auto modified_trees = std::vector<tree_t>{};
37 for (tree_t& tree : concrete_tl_model.trees) {
38 if (tree.IsLeaf(root_id)) {
40 tree.HasDataCount(root_id) ? tree.DataCount(root_id) : std::uint64_t{};
41 auto new_tree = tree_t{};
43 const auto root_id = new_tree.AllocNode();
44 const auto cleft_id = new_tree.AllocNode();
45 const auto cright_id = new_tree.AllocNode();
46 new_tree.SetChildren(root_id, cleft_id, cright_id);
47 new_tree.SetNumericalTest(
48 root_id,
int{},
typename model_t::threshold_type{},
true, treelite::Operator::kLE);
49 if (tree.HasLeafVector(root_id)) {
50 const auto leaf_vector = tree.LeafVector(root_id);
51 new_tree.SetLeafVector(cleft_id, leaf_vector);
52 new_tree.SetLeafVector(cright_id, leaf_vector);
54 const auto leaf_value = tree.LeafValue(root_id);
55 new_tree.SetLeaf(cleft_id, leaf_value);
56 new_tree.SetLeaf(cright_id, leaf_value);
58 new_tree.SetDataCount(root_id, inst_cnt);
59 new_tree.SetDataCount(cleft_id, inst_cnt);
60 new_tree.SetDataCount(cright_id, std::uint64_t{});
61 modified_trees.push_back(std::move(new_tree));
63 modified_trees.push_back(std::move(tree));
66 concrete_tl_model.trees = std::move(modified_trees);
68 modified_model->variant_);
69 return modified_model;
71 return std::unique_ptr<treelite::Model>();
Definition: decision_forest.hpp:352
std::unique_ptr< treelite::Model > convert_degenerate_trees(treelite::Model const &tl_model)
Definition: degenerate_trees.hpp:20
int TREELITE_NODE_ID_T
Definition: treelite.hpp:19
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:183