19 #include <type_traits>
26 namespace experimental {
// Evaluate tree nodes for one input row, starting from `*node`, until a node
// reporting is_leaf() is reached, then return either that leaf's output or an
// entry of `node_id_mapping`.
//
// NOTE(review): this listing is an excerpt. The embedded line numbers skip
// (e.g. 50 -> 53, 57 -> 60, 75 -> 78), so some template parameters, the
// function header, the loop opening, the `else` keywords and the statement
// that moves to the next node are not visible here. The comments below only
// describe the lines that are visible.
//
// Template parameters (visible ones):
//   has_vector_leaves     - forwarded to the node's output<...>() call when
//                           the leaf value is returned.
//   has_categorical_nodes - when true, each node is asked is_categorical()
//                           and may be tested against a category set instead
//                           of a threshold.
//   node_id_mapping_t     - defaults to std::nullptr_t; its type selects which
//                           of the two return statements is compiled.
//
// Parameters (visible ones):
//   row             - input values; indexed with cur_node.feature_index().
//   first_root_node - defaults to nullptr; only used to compute
//                     `node - first_root_node` in the mapping return path.
//   node_id_mapping - defaults to nullptr; indexed only when its type is not
//                     std::nullptr_t.
49 template <
bool has_vector_leaves,
50 bool has_categorical_nodes,
53 typename node_id_mapping_t = std::nullptr_t>
55 io_t
const* __restrict__ row,
56 node_t
const* __restrict__ first_root_node =
nullptr,
57 node_id_mapping_t node_id_mapping =
nullptr)
// Work on a local copy of the node currently being examined.
60 auto cur_node = *
node;
// Fetch the feature value this node tests.
62 auto input_val = row[cur_node.feature_index()];
// `condition` is the outcome of this node's test; it starts out true and is
// overwritten by whichever branch below applies.
63 auto condition =
true;
// Categorical handling is compiled in only when has_categorical_nodes is set.
64 if constexpr (has_categorical_nodes) {
65 if (cur_node.is_categorical()) {
// NOTE(review): the constructor arguments of categorical_set_type are not
// visible in this listing.
66 auto valid_categories = categorical_set_type{
// A NaN input never satisfies the categorical test, whatever test() returns.
68 condition = valid_categories.test(input_val) && !isnan(input_val);
// Threshold comparison (presumably the non-categorical node case; the
// enclosing `else` is not visible).
70 condition = (input_val < cur_node.threshold());
// Threshold comparison (presumably the branch compiled when
// has_categorical_nodes is false; the enclosing `else` is not visible).
73 condition = (input_val < cur_node.threshold());
// If the test failed and the node's default_distant() flag is set, the
// condition becomes true exactly when the input is NaN.
75 if (!condition && cur_node.default_distant()) { condition = isnan(input_val); }
// Repeat until the current node is a leaf. NOTE(review): the loop opening and
// the step that updates `node`/`cur_node` are not visible in this listing.
78 }
while (!cur_node.is_leaf());
// With no mapping type supplied (std::nullptr_t), return the leaf's own output.
79 if constexpr (std::is_same_v<node_id_mapping_t, std::nullptr_t>) {
80 return cur_node.template output<has_vector_leaves>();
// Otherwise return the mapping entry at the offset of `node` from
// `first_root_node`.
82 return node_id_mapping[
node - first_root_node];
// Overload of the tree evaluation routine that reads categorical split data
// from a separate `categorical_storage` array instead of from the node alone.
// It examines nodes starting from `*node` until one reports is_leaf(), then
// returns either that leaf's output or an entry of `node_id_mapping`.
//
// NOTE(review): this listing is an excerpt. The embedded line numbers skip
// (e.g. 113 -> 116, 122 -> 125, 127 -> 129, 136 -> 141), so some template
// parameters, the function header, the loop opening, the declaration of
// `condition`, any handling of NaN inputs and the statement that moves to the
// next node are not visible here.
//
// Template parameters (visible ones):
//   has_vector_leaves     - forwarded to the node's output<...>() call.
//   categorical_storage_t - element type of `categorical_storage`.
//   node_id_mapping_t     - defaults to std::nullptr_t; its type selects which
//                           of the two return statements is compiled.
//
// Parameters (visible ones):
//   row                 - input values; indexed with cur_node.feature_index().
//   categorical_storage - array indexed with cur_node.index() for categorical
//                         nodes.
//   first_root_node     - defaults to nullptr; only used to compute
//                         `node - first_root_node` in the mapping return path.
//   node_id_mapping     - defaults to nullptr; indexed only when its type is
//                         not std::nullptr_t.
113 template <
bool has_vector_leaves,
116 typename categorical_storage_t,
117 typename node_id_mapping_t = std::nullptr_t>
119 io_t
const* __restrict__ row,
120 categorical_storage_t
const* __restrict__ categorical_storage,
121 node_t
const* __restrict__ first_root_node =
nullptr,
122 node_id_mapping_t node_id_mapping =
nullptr)
// Work on a local copy of the node currently being examined.
125 auto cur_node = *
node;
// Fetch the feature value this node tests.
127 auto input_val = row[cur_node.feature_index()];
// The categorical/threshold tests below run only for non-NaN inputs.
// NOTE(review): `condition` is declared on a line not visible here, so its
// value for NaN inputs cannot be determined from this listing.
129 if (!isnan(input_val)) {
130 if (cur_node.is_categorical()) {
// Build the category set from storage: the element at cur_node.index() is
// converted to uint32_t and passed as the second argument, and a pointer to
// the element right after it is passed as the first. Presumably these are the
// set's size and its data -- TODO confirm against categorical_set_type.
131 auto valid_categories =
132 categorical_set_type{categorical_storage + cur_node.index() + 1,
133 uint32_t(categorical_storage[cur_node.index()])};
134 condition = valid_categories.test(input_val);
// Threshold comparison (presumably the non-categorical node case; the
// enclosing `else` is not visible).
136 condition = (input_val < cur_node.threshold());
// Repeat until the current node is a leaf. NOTE(review): the loop opening and
// the step that updates `node`/`cur_node` are not visible in this listing.
141 }
while (!cur_node.is_leaf());
// With no mapping type supplied (std::nullptr_t), return the leaf's own output.
142 if constexpr (std::is_same_v<node_id_mapping_t, std::nullptr_t>) {
143 return cur_node.template output<has_vector_leaves>();
// Otherwise return the mapping entry at the offset of `node` from
// `first_root_node`.
145 return node_id_mapping[
node - first_root_node];
// Entry point that evaluates one tree for one input row by dispatching, at
// compile time, to one of the evaluate_tree_impl overloads.
//
// NOTE(review): this listing is an excerpt. The embedded line numbers skip
// (e.g. 169 -> 173, 177 -> 179, 182 -> 189, 189 -> 197, 204 -> end), so some
// template parameters, the function header, the arguments of every
// evaluate_tree_impl call and the return statements are not visible here.
//
// Template parameters (visible ones):
//   has_vector_leaves       - forwarded to evaluate_tree_impl and used to pick
//                             the type of `tree_output`.
//   has_categorical_nodes   - forwarded to evaluate_tree_impl in the calls
//                             that take two boolean template arguments.
//   has_nonlocal_categories - selects between the two evaluate_tree_impl call
//                             forms.
//   categorical_data_t      - type of the `categorical_data` argument.
//
// Parameters (visible ones):
//   row              - input values for the row being evaluated.
//   categorical_data - NOTE(review): its use is on lines not visible here;
//                      presumably passed to evaluate_tree_impl when
//                      has_nonlocal_categories is true -- TODO confirm.
167 template <
bool has_vector_leaves,
168 bool has_categorical_nodes,
169 bool has_nonlocal_categories,
173 typename categorical_data_t>
176 io_t
const* __restrict__ row,
177 categorical_data_t categorical_data)
// Node type is taken from the forest type.
179 using node_t =
typename forest_t::node_type;
// `predict_leaf` is not declared on any visible line (presumably one of the
// elided template parameters). When it is true, the result is stored in
// `leaf_node_id`.
180 if constexpr (predict_leaf) {
182 if constexpr (has_nonlocal_categories) {
// Call form with both boolean template arguments. NOTE(review): the call
// arguments, and the call made inside the has_nonlocal_categories branch
// above, are not visible in this listing.
189 leaf_node_id = evaluate_tree_impl<has_vector_leaves, has_categorical_nodes>(
// Value-initialized holder for the tree's output. Its type is chosen by
// has_vector_leaves; the false case is node_t::threshold_type, and the true
// case is on a line not visible here.
197 auto tree_output = std::conditional_t<has_vector_leaves,
199 typename node_t::threshold_type>{};
200 if constexpr (has_nonlocal_categories) {
// Nonlocal categories: only has_vector_leaves is given explicitly.
201 tree_output = evaluate_tree_impl<has_vector_leaves>(
// Otherwise: both has_vector_leaves and has_categorical_nodes are given
// explicitly (the enclosing `else` is not visible).
204 tree_output = evaluate_tree_impl<has_vector_leaves, has_categorical_nodes>(
#define DEVICE
Definition: gpu_support.hpp:35
#define HOST
Definition: gpu_support.hpp:34
HOST DEVICE auto evaluate_tree(forest_t const &forest, index_type tree_index, io_t const *__restrict__ row, categorical_data_t categorical_data)
Definition: evaluate_tree.hpp:174
HOST DEVICE auto evaluate_tree_impl(node_t const *__restrict__ node, io_t const *__restrict__ row, node_t const *__restrict__ first_root_node=nullptr, node_id_mapping_t node_id_mapping=nullptr)
Definition: evaluate_tree.hpp:54
uint32_t index_type
Definition: index_type.hpp:21
forest< real_t > * forest_t
Definition: fil.h:89
Definition: dbscan.hpp:30
Definition: bitset.hpp:33
Definition: forest.hpp:36
HOST DEVICE auto * get_tree_root(index_type tree_index) const
Definition: forest.hpp:58
HOST DEVICE const auto * get_node_id_mapping() const
Definition: forest.hpp:65
HOST DEVICE constexpr auto default_distant() const
Definition: node.hpp:161
HOST DEVICE constexpr auto child_offset(bool condition) const
Definition: node.hpp:171
HOST DEVICE auto const & index() const
Definition: node.hpp:188