26 #include <treelite/c_api.h>
27 #include <treelite/enum/task_type.h>
28 #include <treelite/enum/tree_node_type.h>
29 #include <treelite/enum/typeinfo.h>
30 #include <treelite/tree.h>
39 namespace experimental {
45 template <tree_layout layout,
typename T>
48 std::conditional_t<layout == tree_layout::depth_first, std::stack<T>, std::queue<T>>;
49 void add(T
const& val) { data_.push(val); }
50 void add(T
const& hot, T
const& distant)
63 auto result = data_.top();
67 auto result = data_.front();
80 [[nodiscard]]
auto empty() {
return data_.empty(); }
81 auto size() {
return data_.size(); }
99 template <tree_layout layout>
101 template <
typename tl_threshold_t,
typename tl_output_t>
103 treelite::Tree<tl_threshold_t, tl_output_t>
const&
tree;
112 auto result = std::vector<tl_output_t>{};
127 return tree.NodeType(
node_id) == treelite::TreeNodeType::kCategoricalTestNode;
136 result = (default_child ==
tree.RightChild(
node_id));
138 result = (default_child ==
tree.LeftChild(
node_id));
142 if (tl_operator == treelite::Operator::kLT || tl_operator == treelite::Operator::kLE) {
143 result = (default_child ==
tree.LeftChild(
node_id));
145 result = (default_child ==
tree.RightChild(
node_id));
155 auto result = decltype(
tree.CategoryList(
node_id)){};
163 return tl_operator == treelite::Operator::kGT || tl_operator == treelite::Operator::kLE;
167 template <
typename tl_threshold_t,
typename tl_output_t,
typename lambda_t>
168 void node_for_each(treelite::Tree<tl_threshold_t, tl_output_t>
const& tl_tree, lambda_t&& lambda)
170 using node_index_t = decltype(tl_tree.LeftChild(0));
172 to_be_visited.
add(node_index_t{});
176 parent_indices.
add(cur_index);
178 while (!to_be_visited.empty()) {
179 auto node_id = to_be_visited.next();
180 auto remaining_size = to_be_visited.size();
183 tl_tree, node_id, parent_indices.next(), cur_index};
184 lambda(tl_node, node_id);
186 if (!tl_tree.IsLeaf(node_id)) {
187 auto tl_left_id = tl_tree.LeftChild(node_id);
188 auto tl_right_id = tl_tree.RightChild(node_id);
189 auto tl_operator = tl_tree.ComparisonOp(node_id);
190 if (!tl_node.is_categorical()) {
191 if (tl_operator == treelite::Operator::kLT || tl_operator == treelite::Operator::kLE) {
192 to_be_visited.add(tl_right_id, tl_left_id);
193 }
else if (tl_operator == treelite::Operator::kGT ||
194 tl_operator == treelite::Operator::kGE) {
195 to_be_visited.add(tl_left_id, tl_right_id);
200 if (tl_tree.CategoryListRightChild(node_id)) {
201 to_be_visited.add(tl_left_id, tl_right_id);
203 to_be_visited.add(tl_right_id, tl_left_id);
206 parent_indices.add(cur_index, cur_index);
212 template <
typename tl_threshold_t,
typename tl_output_t,
typename iter_t,
typename lambda_t>
217 node_for_each(tl_tree, [&output_iter, &lambda](
auto&& tl_node,
int tl_node_id) {
218 *output_iter = lambda(tl_node);
223 template <
typename tl_threshold_t,
typename tl_output_t,
typename T,
typename lambda_t>
229 node_for_each(tl_tree, [&result, &lambda](
auto&& tl_node,
int tl_node_id) {
230 result = lambda(result, tl_node);
235 template <
typename tl_threshold_t,
typename tl_output_t>
236 auto get_nodes(treelite::Tree<tl_threshold_t, tl_output_t>
const& tl_tree)
238 auto result = std::vector<treelite_node<tl_threshold_t, tl_output_t>>{};
239 result.reserve(tl_tree.num_nodes);
244 template <
typename tl_threshold_t,
typename tl_output_t>
245 auto get_offsets(treelite::Tree<tl_threshold_t, tl_output_t>
const& tl_tree)
247 auto result = std::vector<index_type>(tl_tree.num_nodes);
249 for (
auto i =
index_type{}; i < nodes.size(); ++i) {
253 result[nodes[i].parent_index] =
index_type{i - nodes[i].parent_index};
259 template <
typename lambda_t>
263 [&lambda](
auto&& concrete_tl_model) {
265 std::begin(concrete_tl_model.trees), std::end(concrete_tl_model.trees), lambda);
270 template <
typename iter_t,
typename lambda_t>
271 void tree_transform(treelite::Model
const& tl_model, iter_t output_iter, lambda_t&& lambda)
274 [&output_iter, &lambda](
auto&& concrete_tl_model) {
276 std::end(concrete_tl_model.trees),
283 template <
typename T,
typename lambda_t>
287 tree_for_each(tl_model, [&result, &lambda](
auto&& tree) { result = lambda(result, tree); });
294 std::visit([&result](
auto&& concrete_tl_model) { result = concrete_tl_model.trees.size(); },
301 auto result = std::vector<std::vector<index_type>>{};
304 tl_model, std::back_inserter(result), [
this](
auto&& tree) {
return get_offsets(tree); });
310 auto result = std::vector<index_type>{};
312 tl_model, std::back_inserter(result), [](
auto&& tree) {
return tree.num_nodes; });
318 return static_cast<index_type>(tl_model.num_class[0]);
323 return static_cast<index_type>(tl_model.num_feature);
329 return node_accumulate(tree, accum, [](
auto&& cur_accum,
auto&& tl_node) {
330 auto result = cur_accum;
331 for (
auto&& cat : tl_node.categories()) {
332 result = (cat + 1 > result) ? cat + 1 : result;
342 return node_accumulate(tree, accum, [](
auto&& cur_accum,
auto&& tl_node) {
343 return cur_accum + tl_node.is_categorical();
351 return node_accumulate(tree, accum, [](
auto&& cur_accum,
auto&& tl_node) {
352 return cur_accum + (tl_node.is_leaf() && tl_node.get_output().size() > 1);
359 auto result =
double{};
360 if (tl_model.average_tree_output) {
361 if (tl_model.task_type == treelite::TaskType::kMultiClf &&
362 tl_model.leaf_vector_shape[1] == 1) {
363 result =
num_trees(tl_model) / tl_model.num_class[0];
375 return static_cast<double>(tl_model.base_scores[0]);
381 auto tl_pred_transform = tl_model.postprocessor;
382 if (tl_pred_transform == std::string{
"identity"} ||
383 tl_pred_transform == std::string{
"identity_multiclass"}) {
386 }
else if (tl_pred_transform == std::string{
"signed_square"}) {
388 }
else if (tl_pred_transform == std::string{
"hinge"}) {
390 }
else if (tl_pred_transform == std::string{
"sigmoid"}) {
391 result.constant = tl_model.sigmoid_alpha;
393 }
else if (tl_pred_transform == std::string{
"exponential"}) {
395 }
else if (tl_pred_transform == std::string{
"exponential_standard_ratio"}) {
396 result.constant = -tl_model.ratio_c / std::log(2);
398 }
else if (tl_pred_transform == std::string{
"logarithm_one_plus_exp"}) {
400 }
else if (tl_pred_transform == std::string{
"max_index"}) {
402 }
else if (tl_pred_transform == std::string{
"softmax"}) {
404 }
else if (tl_pred_transform == std::string{
"multiclass_ova"}) {
405 result.constant = tl_model.sigmoid_alpha;
416 switch (tl_model.GetThresholdType()) {
417 case treelite::TypeInfo::kFloat64: result =
true;
break;
418 case treelite::TypeInfo::kFloat32: result =
false;
break;
427 switch (tl_model.GetThresholdType()) {
428 case treelite::TypeInfo::kFloat64: result =
true;
break;
429 case treelite::TypeInfo::kFloat32: result =
false;
break;
430 case treelite::TypeInfo::kUInt32: result =
false;
break;
439 switch (tl_model.GetThresholdType()) {
440 case treelite::TypeInfo::kFloat64: result =
false;
break;
441 case treelite::TypeInfo::kFloat32: result =
false;
break;
442 case treelite::TypeInfo::kUInt32: result =
true;
break;
452 template <index_type variant_index>
454 treelite::Model
const& tl_model,
458 std::vector<std::vector<index_type>>
const& offsets,
465 if constexpr (variant_index != std::variant_size_v<decision_forest_variant>) {
466 if (variant_index == target_variant_index) {
467 using forest_model_t = std::variant_alternative_t<variant_index, decision_forest_variant>;
469 detail::decision_forest_builder<forest_model_t>(max_num_categories, align_bytes);
472 tree_for_each(tl_model, [
this, &builder, &tree_index, &offsets](
auto&& tree) {
473 builder.start_new_tree();
476 tree, [&builder, &tree_index, &node_index, &offsets](
auto&& node,
int tl_node_id) {
477 if (node.is_leaf()) {
478 auto output = node.get_output();
479 builder.set_output_size(output.size());
480 if (output.size() > index_type{1}) {
481 builder.add_leaf_vector_node(std::begin(output), std::end(output), tl_node_id);
483 builder.add_node(
typename forest_model_t::io_type(output[0]), tl_node_id,
true);
486 if (node.is_categorical()) {
487 auto categories = node.get_categories();
488 builder.add_categorical_node(std::begin(categories),
489 std::end(categories),
491 node.default_distant(),
493 offsets[tree_index][node_index]);
495 builder.add_node(
typename forest_model_t::threshold_type(node.threshold()),
498 node.default_distant(),
501 offsets[tree_index][node_index],
502 node.is_inclusive());
511 builder.set_bias(
get_bias(tl_model));
513 builder.set_element_postproc(postproc_params.element);
514 builder.set_row_postproc(postproc_params.row);
515 builder.set_postproc_constant(postproc_params.constant);
517 result.template emplace<variant_index>(
518 builder.get_decision_forest(num_feature, num_class, mem_type, device, stream));
520 result = import_to_specific_variant<variant_index + 1>(target_variant_index,
557 auto import(treelite::Model
const& tl_model,
559 std::optional<bool> use_double_precision = std::nullopt,
564 ASSERT(tl_model.num_target == 1,
"FIL does not support multi-target model");
566 if (tl_model.task_type == treelite::TaskType::kMultiClf) {
568 if (tl_model.leaf_vector_shape[1] > 1) {
569 ASSERT(tl_model.leaf_vector_shape[1] == tl_model.num_class[0],
570 "Vector leaf must be equal to num_class = %d",
571 tl_model.num_class[0]);
572 auto tree_count = num_trees(tl_model);
573 for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
574 ASSERT(tl_model.class_id[tree_id] == -1,
"Tree %d has invalid class assignment", tree_id);
577 auto tree_count = num_trees(tl_model);
578 for (decltype(tree_count) tree_id = 0; tree_id < tree_count; ++tree_id) {
579 ASSERT(tl_model.class_id[tree_id] == tree_id % tl_model.num_class[0],
580 "Tree %d has invalid class assignment",
586 for (std::int32_t class_id = 1; class_id < tl_model.num_class[0]; ++class_id) {
587 ASSERT(tl_model.base_scores[0] == tl_model.base_scores[class_id],
588 "base_scores must be identical for all classes");
592 auto num_feature = get_num_feature(tl_model);
593 auto max_num_categories = get_max_num_categories(tl_model);
594 auto num_categorical_nodes = get_num_categorical_nodes(tl_model);
595 auto num_leaf_vector_nodes = get_num_leaf_vector_nodes(tl_model);
596 auto use_double_thresholds = use_double_precision.value_or(uses_double_thresholds(tl_model));
598 auto offsets = get_offsets(tl_model);
599 auto max_offset = std::accumulate(
603 [&offsets](
auto&& cur_max,
auto&& tree_offsets) {
605 *std::max_element(std::begin(tree_offsets), std::end(tree_offsets)));
607 auto tree_sizes = std::vector<index_type>{};
610 std::back_inserter(tree_sizes),
611 [](
auto&& tree_offsets) {
return tree_offsets.size(); });
616 num_categorical_nodes,
618 num_leaf_vector_nodes,
620 auto num_class = get_num_class(tl_model);
621 return forest_model{import_to_specific_variant<
index_type{}>(variant_index,
660 std::optional<bool> use_double_precision = std::nullopt,
665 auto result = forest_model{};
667 case tree_layout::depth_first:
668 result = treelite_importer<tree_layout::depth_first>{}.import(
669 tl_model, align_bytes, use_double_precision, dev_type, device, stream);
671 case tree_layout::breadth_first:
672 result = treelite_importer<tree_layout::breadth_first>{}.import(
673 tl_model, align_bytes, use_double_precision, dev_type, device, stream);
706 std::optional<bool> use_double_precision = std::nullopt,
714 use_double_precision,
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
tree_layout
Definition: tree_layout.hpp:20
element_op
Definition: postproc_ops.hpp:29
uint32_t index_type
Definition: index_type.hpp:21
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:436
auto import_from_treelite_handle(TreeliteModelHandle tl_handle, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:703
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:416
row_op
Definition: postproc_ops.hpp:22
auto import_from_treelite_model(treelite::Model const &tl_model, tree_layout layout=preferred_tree_layout, index_type align_bytes=index_type{}, std::optional< bool > use_double_precision=std::nullopt, raft_proto::device_type dev_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:657
void transform(const raft::handle_t &handle, const KMeansParams ¶ms, const float *centroids, const float *X, int n_samples, int n_features, float *X_new)
Transform X to a cluster-distance space.
Definition: dbscan.hpp:30
int cuda_stream
Definition: cuda_stream.hpp:25
device_type
Definition: device_type.hpp:18
Definition: treelite_importer.hpp:87
element_op element
Definition: treelite_importer.hpp:88
double constant
Definition: treelite_importer.hpp:90
row_op row
Definition: treelite_importer.hpp:89
Definition: treelite_importer.hpp:46
auto empty()
Definition: treelite_importer.hpp:80
auto next()
Definition: treelite_importer.hpp:60
auto peek()
Definition: treelite_importer.hpp:72
void add(T const &val)
Definition: treelite_importer.hpp:49
auto size()
Definition: treelite_importer.hpp:81
void add(T const &hot, T const &distant)
Definition: treelite_importer.hpp:50
std::conditional_t< layout==tree_layout::depth_first, std::stack< T >, std::queue< T > > backing_container_t
Definition: treelite_importer.hpp:48
Definition: exceptions.hpp:36
Definition: treelite_importer.hpp:102
auto get_feature()
Definition: treelite_importer.hpp:123
auto is_categorical()
Definition: treelite_importer.hpp:125
auto get_categories()
Definition: treelite_importer.hpp:121
auto threshold()
Definition: treelite_importer.hpp:151
index_type parent_index
Definition: treelite_importer.hpp:105
auto is_leaf()
Definition: treelite_importer.hpp:108
auto get_output()
Definition: treelite_importer.hpp:110
auto default_distant()
Definition: treelite_importer.hpp:130
index_type own_index
Definition: treelite_importer.hpp:106
auto is_inclusive()
Definition: treelite_importer.hpp:160
int node_id
Definition: treelite_importer.hpp:104
treelite::Tree< tl_threshold_t, tl_output_t > const & tree
Definition: treelite_importer.hpp:103
auto categories()
Definition: treelite_importer.hpp:153
Definition: treelite_importer.hpp:100
auto get_nodes(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree)
Definition: treelite_importer.hpp:236
void tree_transform(treelite::Model const &tl_model, iter_t output_iter, lambda_t &&lambda)
Definition: treelite_importer.hpp:271
auto node_accumulate(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, T init, lambda_t &&lambda)
Definition: treelite_importer.hpp:224
void node_transform(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, iter_t output_iter, lambda_t &&lambda)
Definition: treelite_importer.hpp:213
auto uses_double_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:424
auto get_num_feature(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:321
auto get_tree_sizes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:308
auto import_to_specific_variant(index_type target_variant_index, treelite::Model const &tl_model, index_type num_class, index_type num_feature, index_type max_num_categories, std::vector< std::vector< index_type >> const &offsets, index_type align_bytes=index_type{}, raft_proto::device_type mem_type=raft_proto::device_type::cpu, int device=0, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: treelite_importer.hpp:453
auto tree_accumulate(treelite::Model const &tl_model, T init, lambda_t &&lambda)
Definition: treelite_importer.hpp:284
auto get_num_leaf_vector_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:348
void node_for_each(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree, lambda_t &&lambda)
Definition: treelite_importer.hpp:168
auto get_num_class(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:316
auto get_postproc_params(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:378
auto get_offsets(treelite::Tree< tl_threshold_t, tl_output_t > const &tl_tree)
Definition: treelite_importer.hpp:245
auto get_num_categorical_nodes(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:339
auto get_bias(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:373
auto uses_double_thresholds(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:413
auto uses_integer_outputs(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:436
auto get_average_factor(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:357
auto get_offsets(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:299
void tree_for_each(treelite::Model const &tl_model, lambda_t &&lambda)
Definition: treelite_importer.hpp:260
auto get_max_num_categories(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:326
auto num_trees(treelite::Model const &tl_model)
Definition: treelite_importer.hpp:291
void * TreeliteModelHandle
Definition: treelite_defs.hpp:23