20 #include <treelite/tree.h> 
   24 #include <type_traits> 
   33   bool contains_degenerate =
 
   38   if (contains_degenerate) {
 
   40     auto modified_model = treelite::ConcatenateModelObjects({&tl_model});
 
   42       [](
auto&& concrete_tl_model) {
 
   43         using model_t = std::remove_const_t<std::remove_reference_t<decltype(concrete_tl_model)>>;
 
   45           treelite::Tree<typename model_t::threshold_type, typename model_t::leaf_output_type>;
 
   46         auto modified_trees = std::vector<tree_t>{};
 
   48         for (tree_t& tree : concrete_tl_model.trees) {
 
   49           if (tree.IsLeaf(root_id)) {
 
   51               tree.HasDataCount(root_id) ? tree.DataCount(root_id) : std::uint64_t{};
 
   52             auto new_tree = tree_t{};
 
   54             const auto root_id   = new_tree.AllocNode();
 
   55             const auto cleft_id  = new_tree.AllocNode();
 
   56             const auto cright_id = new_tree.AllocNode();
 
   57             new_tree.SetChildren(root_id, cleft_id, cright_id);
 
   58             new_tree.SetNumericalTest(
 
   59               root_id, 
int{}, 
typename model_t::threshold_type{}, 
true, treelite::Operator::kLE);
 
   60             if (tree.HasLeafVector(root_id)) {
 
   61               const auto leaf_vector = tree.LeafVector(root_id);
 
   62               new_tree.SetLeafVector(cleft_id, leaf_vector);
 
   63               new_tree.SetLeafVector(cright_id, leaf_vector);
 
   65               const auto leaf_value = tree.LeafValue(root_id);
 
   66               new_tree.SetLeaf(cleft_id, leaf_value);
 
   67               new_tree.SetLeaf(cright_id, leaf_value);
 
   69             new_tree.SetDataCount(root_id, inst_cnt);
 
   70             new_tree.SetDataCount(cleft_id, inst_cnt);
 
   71             new_tree.SetDataCount(cright_id, std::uint64_t{});
 
   72             modified_trees.push_back(std::move(new_tree));
 
   74             modified_trees.push_back(std::move(tree));
 
   77         concrete_tl_model.trees = std::move(modified_trees);
 
   79       modified_model->variant_);
 
   80     return modified_model;
 
   82     return std::unique_ptr<treelite::Model>();
 
Definition: decision_forest.hpp:358
 
std::unique_ptr< treelite::Model > convert_degenerate_trees(treelite::Model const &tl_model)
Definition: degenerate_trees.hpp:31
 
int TREELITE_NODE_ID_T
Definition: treelite.hpp:30
 
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:187