20 #include <treelite/tree.h>
24 #include <type_traits>
33 bool contains_degenerate =
38 if (contains_degenerate) {
40 auto modified_model = treelite::ConcatenateModelObjects({&tl_model});
42 [](
auto&& concrete_tl_model) {
43 using model_t = std::remove_const_t<std::remove_reference_t<decltype(concrete_tl_model)>>;
45 treelite::Tree<typename model_t::threshold_type, typename model_t::leaf_output_type>;
46 auto modified_trees = std::vector<tree_t>{};
48 for (tree_t& tree : concrete_tl_model.trees) {
49 if (tree.IsLeaf(root_id)) {
51 tree.HasDataCount(root_id) ? tree.DataCount(root_id) : std::uint64_t{};
52 auto new_tree = tree_t{};
54 const auto root_id = new_tree.AllocNode();
55 const auto cleft_id = new_tree.AllocNode();
56 const auto cright_id = new_tree.AllocNode();
57 new_tree.SetChildren(root_id, cleft_id, cright_id);
58 new_tree.SetNumericalTest(
59 root_id,
int{},
typename model_t::threshold_type{},
true, treelite::Operator::kLE);
60 if (tree.HasLeafVector(root_id)) {
61 const auto leaf_vector = tree.LeafVector(root_id);
62 new_tree.SetLeafVector(cleft_id, leaf_vector);
63 new_tree.SetLeafVector(cright_id, leaf_vector);
65 const auto leaf_value = tree.LeafValue(root_id);
66 new_tree.SetLeaf(cleft_id, leaf_value);
67 new_tree.SetLeaf(cright_id, leaf_value);
69 new_tree.SetDataCount(root_id, inst_cnt);
70 new_tree.SetDataCount(cleft_id, inst_cnt);
71 new_tree.SetDataCount(cright_id, std::uint64_t{});
72 modified_trees.push_back(std::move(new_tree));
74 modified_trees.push_back(std::move(tree));
77 concrete_tl_model.trees = std::move(modified_trees);
79 modified_model->variant_);
80 return modified_model;
82 return std::unique_ptr<treelite::Model>();
Definition: decision_forest.hpp:358
std::unique_ptr< treelite::Model > convert_degenerate_trees(treelite::Model const &tl_model)
Definition: degenerate_trees.hpp:31
int TREELITE_NODE_ID_T
Definition: treelite.hpp:30
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:187