19 #include <type_traits> 
   48 template <
bool has_vector_leaves,
 
   49           bool has_categorical_nodes,
 
   52           typename node_id_mapping_t = std::nullptr_t>
 
   54                                     io_t 
const* __restrict__ row,
 
   55                                     node_t 
const* __restrict__ first_root_node = 
nullptr,
 
   56                                     node_id_mapping_t node_id_mapping          = 
nullptr)
 
   59   auto cur_node              = *
node;
 
   61     auto input_val = row[cur_node.feature_index()];
 
   62     auto condition = 
true;
 
   63     if constexpr (has_categorical_nodes) {
 
   64       if (cur_node.is_categorical()) {
 
   65         auto valid_categories = categorical_set_type{
 
   67         condition = valid_categories.test(input_val) && !isnan(input_val);
 
   69         condition = (input_val < cur_node.threshold());
 
   72       condition = (input_val < cur_node.threshold());
 
   74     if (!condition && cur_node.default_distant()) { condition = isnan(input_val); }
 
   77   } 
while (!cur_node.is_leaf());
 
   78   if constexpr (std::is_same_v<node_id_mapping_t, std::nullptr_t>) {
 
   79     return cur_node.template output<has_vector_leaves>();
 
   81     return node_id_mapping[
node - first_root_node];
 
  112 template <
bool has_vector_leaves,
 
  115           typename categorical_storage_t,
 
  116           typename node_id_mapping_t = std::nullptr_t>
 
  118                                     io_t 
const* __restrict__ row,
 
  119                                     categorical_storage_t 
const* __restrict__ categorical_storage,
 
  120                                     node_t 
const* __restrict__ first_root_node = 
nullptr,
 
  121                                     node_id_mapping_t node_id_mapping          = 
nullptr)
 
  124   auto cur_node              = *
node;
 
  126     auto input_val = row[cur_node.feature_index()];
 
  128     if (!isnan(input_val)) {
 
  129       if (cur_node.is_categorical()) {
 
  130         auto valid_categories =
 
  131           categorical_set_type{categorical_storage + cur_node.index() + 1,
 
  132                                uint32_t(categorical_storage[cur_node.index()])};
 
  133         condition = valid_categories.test(input_val);
 
  135         condition = (input_val < cur_node.threshold());
 
  140   } 
while (!cur_node.is_leaf());
 
  141   if constexpr (std::is_same_v<node_id_mapping_t, std::nullptr_t>) {
 
  142     return cur_node.template output<has_vector_leaves>();
 
  144     return node_id_mapping[
node - first_root_node];
 
  166 template <
bool has_vector_leaves,
 
  167           bool has_categorical_nodes,
 
  168           bool has_nonlocal_categories,
 
  172           typename categorical_data_t>
 
  175                                io_t 
const* __restrict__ row,
 
  176                                categorical_data_t categorical_data)
 
  178   using node_t = 
typename forest_t::node_type;
 
  179   if constexpr (predict_leaf) {
 
  181     if constexpr (has_nonlocal_categories) {
 
  188       leaf_node_id = evaluate_tree_impl<has_vector_leaves, has_categorical_nodes>(
 
  196     auto tree_output = std::conditional_t<has_vector_leaves,
 
  198                                           typename node_t::threshold_type>{};
 
  199     if constexpr (has_nonlocal_categories) {
 
  200       tree_output = evaluate_tree_impl<has_vector_leaves>(
 
  203       tree_output = evaluate_tree_impl<has_vector_leaves, has_categorical_nodes>(
 
#define DEVICE
Definition: gpu_support.hpp:35
 
#define HOST
Definition: gpu_support.hpp:34
 
HOST DEVICE auto evaluate_tree_impl(node_t const *__restrict__ node, io_t const *__restrict__ row, node_t const *__restrict__ first_root_node=nullptr, node_id_mapping_t node_id_mapping=nullptr)
Definition: evaluate_tree.hpp:53
 
HOST DEVICE auto evaluate_tree(forest_t const &forest, index_type tree_index, io_t const *__restrict__ row, categorical_data_t categorical_data)
Definition: evaluate_tree.hpp:173
 
uint32_t index_type
Definition: index_type.hpp:20
 
Definition: dbscan.hpp:29
 
Definition: bitset.hpp:32
 
Definition: forest.hpp:35
 
HOST DEVICE auto * get_tree_root(index_type tree_index) const
Definition: forest.hpp:57
 
HOST DEVICE const auto * get_node_id_mapping() const
Definition: forest.hpp:64
 
HOST DEVICE auto const  & index() const
Definition: node.hpp:188
 
HOST DEVICE constexpr auto child_offset(bool condition) const
Definition: node.hpp:170
 
HOST DEVICE constexpr auto default_distant() const
Definition: node.hpp:160