28 #include <treelite/c_api.h> 
   29 #include <treelite/enum/task_type.h> 
   30 #include <treelite/enum/tree_node_type.h> 
   31 #include <treelite/enum/typeinfo.h> 
   32 #include <treelite/tree.h> 
   54 template <tree_layout layout>
 
   65                     "Layout not yet implemented in treelite importer for FIL");
 
   72       tl_model, 
index_type{}, [](
auto&& count, 
auto&& tree) { 
return count + tree.num_nodes; });
 
   79     auto result         = std::vector<index_type>(node_count);
 
   80     auto parent_indexes = std::vector<index_type>{};
 
   81     parent_indexes.reserve(node_count);
 
   82     ML::forest::node_transform<traversal_order>(
 
   84       std::back_inserter(parent_indexes),
 
   85       [](
auto&& tree_id, 
auto&& 
node, 
auto&& depth, 
auto&& parent_index) { 
return parent_index; });
 
   86     for (
auto i = std::size_t{}; i < node_count; ++i) {
 
   87       result[parent_indexes[i]] = i - parent_indexes[i];
 
   95     std::visit([&result](
auto&& concrete_tl_model) { result = concrete_tl_model.trees.size(); },
 
  102     auto result = std::vector<index_type>{};
 
  104       tl_model, std::back_inserter(result), [](
auto&& tree) { 
return tree.num_nodes; });
 
  110     return static_cast<index_type>(tl_model.num_class[0]);
 
  115     return static_cast<index_type>(tl_model.num_feature);
 
  120     return ML::forest::node_accumulate<traversal_order>(
 
  123       [](
auto&& cur_accum, 
auto&& tree_id, 
auto&& 
node, 
auto&& depth, 
auto&& parent_index) {
 
  130     return ML::forest::node_accumulate<traversal_order>(
 
  133       [](
auto&& cur_accum, 
auto&& tree_id, 
auto&& 
node, 
auto&& depth, 
auto&& parent_index) {
 
  140     return ML::forest::node_accumulate<traversal_order>(
 
  143       [](
auto&& cur_accum, 
auto&& tree_id, 
auto&& 
node, 
auto&& depth, 
auto&& parent_index) {
 
  144         auto accum = cur_accum;
 
  152     auto result = 
double{};
 
  153     if (tl_model.average_tree_output) {
 
  154       if (tl_model.task_type == treelite::TaskType::kMultiClf &&
 
  155           tl_model.leaf_vector_shape[1] == 1) {  
 
  156         result = 
num_trees(tl_model) / tl_model.num_class[0];
 
  168     return static_cast<double>(tl_model.base_scores[0]);
 
  174     auto tl_pred_transform = tl_model.postprocessor;
 
  175     if (tl_pred_transform == std::string{
"identity"} ||
 
  176         tl_pred_transform == std::string{
"identity_multiclass"}) {
 
  179     } 
else if (tl_pred_transform == std::string{
"signed_square"}) {
 
  181     } 
else if (tl_pred_transform == std::string{
"hinge"}) {
 
  183     } 
else if (tl_pred_transform == std::string{
"sigmoid"}) {
 
  184       result.constant = tl_model.sigmoid_alpha;
 
  186     } 
else if (tl_pred_transform == std::string{
"exponential"}) {
 
  188     } 
else if (tl_pred_transform == std::string{
"exponential_standard_ratio"}) {
 
  189       result.constant = -tl_model.ratio_c / std::log(2);
 
  191     } 
else if (tl_pred_transform == std::string{
"logarithm_one_plus_exp"}) {
 
  193     } 
else if (tl_pred_transform == std::string{
"max_index"}) {
 
  195     } 
else if (tl_pred_transform == std::string{
"softmax"}) {
 
  197     } 
else if (tl_pred_transform == std::string{
"multiclass_ova"}) {
 
  198       result.constant = tl_model.sigmoid_alpha;
 
  209     switch (tl_model.GetThresholdType()) {
 
  210       case treelite::TypeInfo::kFloat64: result = 
true; 
break;
 
  211       case treelite::TypeInfo::kFloat32: result = 
false; 
break;
 
  220     switch (tl_model.GetThresholdType()) {
 
  221       case treelite::TypeInfo::kFloat64: result = 
true; 
break;
 
  222       case treelite::TypeInfo::kFloat32: result = 
false; 
break;
 
  223       case treelite::TypeInfo::kUInt32: result = 
false; 
break;
 
  232     switch (tl_model.GetThresholdType()) {
 
  233       case treelite::TypeInfo::kFloat64: result = 
false; 
break;
 
  234       case treelite::TypeInfo::kFloat32: result = 
false; 
break;
 
  235       case treelite::TypeInfo::kUInt32: result = 
true; 
break;
 
  245   template <index_type variant_index>
 
  247                                   treelite::Model 
const& tl_model,
 
  251                                   std::vector<index_type> 
const& offsets,
 
  258     if constexpr (variant_index != std::variant_size_v<decision_forest_variant>) {
 
  259       if (variant_index == target_variant_index) {
 
  260         using forest_model_t = std::variant_alternative_t<variant_index, decision_forest_variant>;
 
  267           detail::decision_forest_builder<forest_model_t>(max_num_categories, align_bytes);
 
  269         ML::forest::node_for_each<traversal_order>(
 
  271           [&builder, &offsets, &node_index](
 
  272             auto&& tree_id, 
auto&& node, 
auto&& depth, 
auto&& parent_index) {
 
  273             if (node.is_leaf()) {
 
  274               auto output = node.get_output();
 
  275               builder.set_output_size(output.size());
 
  276               if (output.size() > index_type{1}) {
 
  277                 builder.add_leaf_vector_node(
 
  278                   std::begin(output), std::end(output), node.get_treelite_id(), depth);
 
  281                   typename forest_model_t::io_type(output[0]), node.get_treelite_id(), depth, 
true);
 
  284               if (node.is_categorical()) {
 
  285                 auto categories = node.get_categories();
 
  286                 builder.add_categorical_node(std::begin(categories),
 
  287                                              std::end(categories),
 
  288                                              node.get_treelite_id(),
 
  290                                              node.default_distant(),
 
  292                                              offsets[node_index]);
 
  294                 builder.add_node(
typename forest_model_t::threshold_type(node.threshold()),
 
  295                                  node.get_treelite_id(),
 
  298                                  node.default_distant(),
 
  302                                  node.is_inclusive());
 
  309         builder.set_bias(
get_bias(tl_model));
 
  311         builder.set_element_postproc(postproc_params.element);
 
  312         builder.set_row_postproc(postproc_params.row);
 
  313         builder.set_postproc_constant(postproc_params.constant);
 
  315         result.template emplace<variant_index>(
 
  316           builder.get_decision_forest(num_feature, num_class, mem_type, device, stream));
 
  318         result = import_to_specific_variant<variant_index + 1>(target_variant_index,
 
  357                       std::optional<bool> use_double_precision = std::nullopt,
 
  365         *processed_tl_model.get(), align_bytes, use_double_precision, dev_type, device, stream);
 
  368     ASSERT(tl_model.num_target == 1, 
"FIL does not support multi-target model");
 
  370     if (tl_model.task_type == treelite::TaskType::kMultiClf) {
 
  372       if (tl_model.leaf_vector_shape[1] > 1) {  
 
  373         ASSERT(tl_model.leaf_vector_shape[1] == 
int(tl_model.num_class[0]),
 
  374                "Vector leaf must be equal to num_class = %d",
 
  375                tl_model.num_class[0]);
 
  376         auto tree_count = num_trees(tl_model);
 
  377         for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
 
  378           ASSERT(tl_model.class_id[tree_id] == -1, 
"Tree %d has invalid class assignment", tree_id);
 
  381         auto tree_count = num_trees(tl_model);
 
  382         for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
 
  383           ASSERT(tl_model.class_id[tree_id] == 
int(tree_id % tl_model.num_class[0]),
 
  384                  "Tree %d has invalid class assignment",
 
  390     for (std::int32_t class_id = 1; class_id < tl_model.num_class[0]; ++class_id) {
 
  391       ASSERT(tl_model.base_scores[0] == tl_model.base_scores[class_id],
 
  392              "base_scores must be identical for all classes");
 
  396     auto num_feature           = get_num_feature(tl_model);
 
  397     auto max_num_categories    = get_max_num_categories(tl_model);
 
  398     auto num_categorical_nodes = get_num_categorical_nodes(tl_model);
 
  399     auto num_leaf_vector_nodes = get_num_leaf_vector_nodes(tl_model);
 
  400     auto use_double_thresholds = use_double_precision.value_or(uses_double_thresholds(tl_model));
 
  402     auto offsets    = get_offsets(tl_model);
 
  403     auto max_offset = *std::max_element(std::begin(offsets), std::end(offsets));
 
  408                                                   num_categorical_nodes,
 
  410                                                   num_leaf_vector_nodes,
 
  412     auto num_class     = get_num_class(tl_model);
 
  413     return forest_model{import_to_specific_variant<
index_type{}>(variant_index,
 
  452                                 std::optional<bool> use_double_precision = std::nullopt,
 
  457   auto result = forest_model{};
 
  459     case tree_layout::depth_first:
 
  460       result = treelite_importer<tree_layout::depth_first>{}.import(
 
  461         tl_model, align_bytes, use_double_precision, dev_type, device, stream);
 
  463     case tree_layout::breadth_first:
 
  464       result = treelite_importer<tree_layout::breadth_first>{}.import(
 
  465         tl_model, align_bytes, use_double_precision, dev_type, device, stream);
 
  467     case tree_layout::layered_children_together:
 
  468       result = treelite_importer<tree_layout::layered_children_together>{}.import(
 
  469         tl_model, align_bytes, use_double_precision, dev_type, device, stream);
 
  502                                  std::optional<bool> use_double_precision = std::nullopt,
 
  510                                     use_double_precision,
 
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
 
std::unique_ptr< treelite::Model > convert_degenerate_trees(treelite::Model const &tl_model)
Definition: degenerate_trees.hpp:31
 
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:451
 
tree_layout
Definition: tree_layout.hpp:19
 
@ layered_children_together
 
row_op
Definition: postproc_ops.hpp:21
 
element_op
Definition: postproc_ops.hpp:28
 
auto import_from_treelite_handle(TreeliteModelHandle tl_handle, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:499
 
auto import_from_treelite_model(treelite::Model const &tl_model, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:449
 
uint32_t index_type
Definition: index_type.hpp:20
 
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:431
 
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:187
 
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:176
 
@ layered_children_together
 
Definition: dbscan.hpp:29
 
int cuda_stream
Definition: cuda_stream.hpp:25
 
device_type
Definition: device_type.hpp:18
 
Definition: treelite_importer.hpp:42
 
element_op element
Definition: treelite_importer.hpp:43
 
row_op row
Definition: treelite_importer.hpp:44
 
double constant
Definition: treelite_importer.hpp:45
 
Definition: forest_model.hpp:37
 
Definition: exceptions.hpp:35
 
HOST DEVICE constexpr auto is_categorical() const
Definition: node.hpp:165
 
HOST DEVICE constexpr auto is_leaf() const
Definition: node.hpp:155
 
Definition: treelite_importer.hpp:55
 
auto uses_integer_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:229
 
auto get_postproc_params(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:171
 
auto get_num_feature(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:113
 
auto get_max_num_categories(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:118
 
auto uses_double_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:217
 
auto get_bias(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:166
 
static constexpr auto const traversal_order
Definition: treelite_importer.hpp:56
 
auto get_num_leaf_vector_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:138
 
auto import_to_specific_variant(index_type target_variant_index, treelite::Model const &tl_model, index_type num_class, index_type num_feature, index_type max_num_categories, std::vector< index_type > const &offsets, index_type align_bytes=index_type{}, raft_proto::device_type mem_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:246
 
auto uses_double_thresholds(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:206
 
auto num_trees(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:92
 
auto get_num_categorical_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:128
 
auto get_tree_sizes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:100
 
auto get_offsets(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:76
 
auto get_node_count(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:69
 
auto get_num_class(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:108
 
auto get_average_factor(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:150
 
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23