28 #include <treelite/c_api.h>
29 #include <treelite/enum/task_type.h>
30 #include <treelite/enum/tree_node_type.h>
31 #include <treelite/enum/typeinfo.h>
32 #include <treelite/tree.h>
38 namespace experimental {
55 template <tree_layout layout>
66 "Layout not yet implemented in treelite importer for FIL");
73 tl_model,
index_type{}, [](
auto&& count,
auto&& tree) {
return count + tree.num_nodes; });
80 auto result = std::vector<index_type>(node_count);
81 auto parent_indexes = std::vector<index_type>{};
82 parent_indexes.reserve(node_count);
83 ML::experimental::forest::node_transform<traversal_order>(
85 std::back_inserter(parent_indexes),
86 [](
auto&& tree_id,
auto&&
node,
auto&& depth,
auto&& parent_index) {
return parent_index; });
87 for (
auto i = std::size_t{}; i < node_count; ++i) {
88 result[parent_indexes[i]] = i - parent_indexes[i];
96 std::visit([&result](
auto&& concrete_tl_model) { result = concrete_tl_model.trees.size(); },
103 auto result = std::vector<index_type>{};
105 tl_model, std::back_inserter(result), [](
auto&& tree) {
return tree.num_nodes; });
111 return static_cast<index_type>(tl_model.num_class[0]);
116 return static_cast<index_type>(tl_model.num_feature);
121 return ML::experimental::forest::node_accumulate<traversal_order>(
124 [](
auto&& cur_accum,
auto&& tree_id,
auto&&
node,
auto&& depth,
auto&& parent_index) {
131 return ML::experimental::forest::node_accumulate<traversal_order>(
134 [](
auto&& cur_accum,
auto&& tree_id,
auto&&
node,
auto&& depth,
auto&& parent_index) {
141 return ML::experimental::forest::node_accumulate<traversal_order>(
144 [](
auto&& cur_accum,
auto&& tree_id,
auto&&
node,
auto&& depth,
auto&& parent_index) {
145 auto accum = cur_accum;
153 auto result =
double{};
154 if (tl_model.average_tree_output) {
155 if (tl_model.task_type == treelite::TaskType::kMultiClf &&
156 tl_model.leaf_vector_shape[1] == 1) {
157 result =
num_trees(tl_model) / tl_model.num_class[0];
169 return static_cast<double>(tl_model.base_scores[0]);
175 auto tl_pred_transform = tl_model.postprocessor;
176 if (tl_pred_transform == std::string{
"identity"} ||
177 tl_pred_transform == std::string{
"identity_multiclass"}) {
180 }
else if (tl_pred_transform == std::string{
"signed_square"}) {
182 }
else if (tl_pred_transform == std::string{
"hinge"}) {
184 }
else if (tl_pred_transform == std::string{
"sigmoid"}) {
185 result.constant = tl_model.sigmoid_alpha;
187 }
else if (tl_pred_transform == std::string{
"exponential"}) {
189 }
else if (tl_pred_transform == std::string{
"exponential_standard_ratio"}) {
190 result.constant = -tl_model.ratio_c / std::log(2);
192 }
else if (tl_pred_transform == std::string{
"logarithm_one_plus_exp"}) {
194 }
else if (tl_pred_transform == std::string{
"max_index"}) {
196 }
else if (tl_pred_transform == std::string{
"softmax"}) {
198 }
else if (tl_pred_transform == std::string{
"multiclass_ova"}) {
199 result.constant = tl_model.sigmoid_alpha;
210 switch (tl_model.GetThresholdType()) {
211 case treelite::TypeInfo::kFloat64: result =
true;
break;
212 case treelite::TypeInfo::kFloat32: result =
false;
break;
221 switch (tl_model.GetThresholdType()) {
222 case treelite::TypeInfo::kFloat64: result =
true;
break;
223 case treelite::TypeInfo::kFloat32: result =
false;
break;
224 case treelite::TypeInfo::kUInt32: result =
false;
break;
233 switch (tl_model.GetThresholdType()) {
234 case treelite::TypeInfo::kFloat64: result =
false;
break;
235 case treelite::TypeInfo::kFloat32: result =
false;
break;
236 case treelite::TypeInfo::kUInt32: result =
true;
break;
246 template <index_type variant_index>
248 treelite::Model
const& tl_model,
252 std::vector<index_type>
const& offsets,
259 if constexpr (variant_index != std::variant_size_v<decision_forest_variant>) {
260 if (variant_index == target_variant_index) {
261 using forest_model_t = std::variant_alternative_t<variant_index, decision_forest_variant>;
269 detail::decision_forest_builder<forest_model_t>(max_num_categories, align_bytes);
271 ML::experimental::forest::node_for_each<traversal_order>(
273 [&builder, &offsets, &node_index](
274 auto&& tree_id,
auto&& node,
auto&& depth,
auto&& parent_index) {
275 if (node.is_leaf()) {
276 auto output = node.get_output();
277 builder.set_output_size(output.size());
278 if (output.size() > index_type{1}) {
279 builder.add_leaf_vector_node(
280 std::begin(output), std::end(output), node.get_treelite_id(), depth);
283 typename forest_model_t::io_type(output[0]), node.get_treelite_id(), depth,
true);
286 if (node.is_categorical()) {
287 auto categories = node.get_categories();
288 builder.add_categorical_node(std::begin(categories),
289 std::end(categories),
290 node.get_treelite_id(),
292 node.default_distant(),
294 offsets[node_index]);
296 builder.add_node(
typename forest_model_t::threshold_type(node.threshold()),
297 node.get_treelite_id(),
300 node.default_distant(),
304 node.is_inclusive());
311 builder.set_bias(
get_bias(tl_model));
313 builder.set_element_postproc(postproc_params.element);
314 builder.set_row_postproc(postproc_params.row);
315 builder.set_postproc_constant(postproc_params.constant);
317 result.template emplace<variant_index>(
318 builder.get_decision_forest(num_feature, num_class, mem_type, device, stream));
320 result = import_to_specific_variant<variant_index + 1>(target_variant_index,
359 std::optional<bool> use_double_precision = std::nullopt,
367 *processed_tl_model.get(), align_bytes, use_double_precision, dev_type, device, stream);
370 ASSERT(tl_model.num_target == 1,
"FIL does not support multi-target model");
372 if (tl_model.task_type == treelite::TaskType::kMultiClf) {
374 if (tl_model.leaf_vector_shape[1] > 1) {
375 ASSERT(tl_model.leaf_vector_shape[1] ==
int(tl_model.num_class[0]),
376 "Vector leaf must be equal to num_class = %d",
377 tl_model.num_class[0]);
378 auto tree_count = num_trees(tl_model);
379 for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
380 ASSERT(tl_model.class_id[tree_id] == -1,
"Tree %d has invalid class assignment", tree_id);
383 auto tree_count = num_trees(tl_model);
384 for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
385 ASSERT(tl_model.class_id[tree_id] ==
int(tree_id % tl_model.num_class[0]),
386 "Tree %d has invalid class assignment",
392 for (std::int32_t class_id = 1; class_id < tl_model.num_class[0]; ++class_id) {
393 ASSERT(tl_model.base_scores[0] == tl_model.base_scores[class_id],
394 "base_scores must be identical for all classes");
398 auto num_feature = get_num_feature(tl_model);
399 auto max_num_categories = get_max_num_categories(tl_model);
400 auto num_categorical_nodes = get_num_categorical_nodes(tl_model);
401 auto num_leaf_vector_nodes = get_num_leaf_vector_nodes(tl_model);
402 auto use_double_thresholds = use_double_precision.value_or(uses_double_thresholds(tl_model));
404 auto offsets = get_offsets(tl_model);
405 auto max_offset = *std::max_element(std::begin(offsets), std::end(offsets));
410 num_categorical_nodes,
412 num_leaf_vector_nodes,
414 auto num_class = get_num_class(tl_model);
415 return forest_model{import_to_specific_variant<
index_type{}>(variant_index,
454 std::optional<bool> use_double_precision = std::nullopt,
459 auto result = forest_model{};
461 case tree_layout::depth_first:
462 result = treelite_importer<tree_layout::depth_first>{}.import(
463 tl_model, align_bytes, use_double_precision, dev_type, device, stream);
465 case tree_layout::breadth_first:
466 result = treelite_importer<tree_layout::breadth_first>{}.import(
467 tl_model, align_bytes, use_double_precision, dev_type, device, stream);
469 case tree_layout::layered_children_together:
470 result = treelite_importer<tree_layout::layered_children_together>{}.import(
471 tl_model, align_bytes, use_double_precision, dev_type, device, stream);
504 std::optional<bool> use_double_precision = std::nullopt,
512 use_double_precision,
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
std::unique_ptr< treelite::Model > convert_degenerate_trees(treelite::Model const &tl_model)
Definition: degenerate_trees.hpp:31
tree_layout
Definition: tree_layout.hpp:20
@ layered_children_together
element_op
Definition: postproc_ops.hpp:29
uint32_t index_type
Definition: index_type.hpp:21
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:432
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:452
auto import_from_treelite_handle(TreeliteModelHandle tl_handle, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:501
row_op
Definition: postproc_ops.hpp:22
auto import_from_treelite_model(treelite::Model const &tl_model, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:451
void tree_transform(treelite::Model const &tl_model, iter_t out_iter, lambda_t &&lambda)
Definition: treelite.hpp:177
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite.hpp:188
@ layered_children_together
Definition: dbscan.hpp:30
int cuda_stream
Definition: cuda_stream.hpp:25
device_type
Definition: device_type.hpp:18
Definition: treelite_importer.hpp:43
element_op element
Definition: treelite_importer.hpp:44
double constant
Definition: treelite_importer.hpp:46
row_op row
Definition: treelite_importer.hpp:45
Definition: forest_model.hpp:38
Definition: exceptions.hpp:36
HOST DEVICE constexpr auto is_categorical() const
Definition: node.hpp:166
HOST DEVICE constexpr auto is_leaf() const
Definition: node.hpp:156
Definition: treelite_importer.hpp:56
auto import_to_specific_variant(index_type target_variant_index, treelite::Model const &tl_model, index_type num_class, index_type num_feature, index_type max_num_categories, std::vector< index_type > const &offsets, index_type align_bytes=index_type{}, raft_proto::device_type mem_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:247
auto uses_double_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:218
auto get_num_feature(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:114
auto get_tree_sizes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:101
auto get_num_leaf_vector_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:139
auto get_num_class(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:109
auto get_postproc_params(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:172
static constexpr auto const traversal_order
Definition: treelite_importer.hpp:57
auto get_node_count(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:70
auto get_num_categorical_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:129
auto get_bias(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:167
auto uses_double_thresholds(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:207
auto uses_integer_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:230
auto get_average_factor(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:151
auto get_offsets(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:77
auto get_max_num_categories(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:119
auto num_trees(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:93
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23