evaluate_tree.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 #include <stdint.h>
7 
8 #include <type_traits>
9 #ifndef __CUDACC__
10 #include <math.h>
11 #endif
14 namespace ML {
15 namespace fil {
16 namespace detail {
17 
18 /*
19  * Evaluate a single tree on a single row.
20  * If node_id_mapping is not-nullptr, this kernel outputs leaf node's ID
21  * instead of the leaf value.
22  *
23  * @tparam has_vector_leaves Whether or not this tree has vector leaves
24  * @tparam has_categorical_nodes Whether or not this tree has any nodes with
25  * categorical splits
26  * @tparam node_t The type of nodes in this tree
27  * @tparam io_t The type used for input to and output from this tree (typically
28  * either floats or doubles)
29  * @tparam node_id_mapping_t If non-nullptr_t, this indicates the type we expect for
30  * node_id_mapping.
31  * @param node Pointer to the root node of this tree
32  * @param row Pointer to the input data for this row
33  * @param first_root_node Pointer to the root node of the first tree.
34  * @param node_id_mapping Array representing the mapping from internal node IDs to
35  * final leaf ID outputs
36  */
37 template <bool has_vector_leaves,
38  bool has_categorical_nodes,
39  typename node_t,
40  typename io_t,
41  typename node_id_mapping_t = std::nullptr_t>
42 HOST DEVICE auto evaluate_tree_impl(node_t const* __restrict__ node,
43  io_t const* __restrict__ row,
44  node_t const* __restrict__ first_root_node = nullptr,
45  node_id_mapping_t node_id_mapping = nullptr)
46 {
47  using categorical_set_type = bitset<uint32_t, typename node_t::index_type const>;
48  auto cur_node = *node;
49  do {
50  auto input_val = row[cur_node.feature_index()];
51  auto condition = true;
52  if constexpr (has_categorical_nodes) {
53  if (cur_node.is_categorical()) {
54  auto valid_categories = categorical_set_type{
55  &cur_node.index(), uint32_t(sizeof(typename node_t::index_type) * 8)};
56  condition = valid_categories.test(input_val) && !isnan(input_val);
57  } else {
58  condition = (input_val < cur_node.threshold());
59  }
60  } else {
61  condition = (input_val < cur_node.threshold());
62  }
63  if (!condition && cur_node.default_distant()) { condition = isnan(input_val); }
64  node += cur_node.child_offset(condition);
65  cur_node = *node;
66  } while (!cur_node.is_leaf());
67  if constexpr (std::is_same_v<node_id_mapping_t, std::nullptr_t>) {
68  return cur_node.template output<has_vector_leaves>();
69  } else {
70  return node_id_mapping[node - first_root_node];
71  }
72 }
73 
74 /*
75  * Evaluate a single tree which requires external categorical storage on a
76  * single node.
77  * If node_id_mapping is not-nullptr, this kernel outputs leaf node's ID
78  * instead of the leaf value.
79  *
80  * For non-categorical models and models with a relatively small number of
81  * categories for any feature, all information necessary for model evaluation
82  * can be stored on a single node. If the number of categories for any
83  * feature exceeds the available space on a node, however, the
84  * categorical split data must be stored external to the node. We pass a
85  * pointer to this external data and reconstruct bitsets from it indicating
86  * the positive and negative categories for each categorical node.
87  *
88  * @tparam has_vector_leaves Whether or not this tree has vector leaves
89  * @tparam node_t The type of nodes in this tree
90  * @tparam io_t The type used for input to and output from this tree (typically
91  * either floats or doubles)
92  * @tparam categorical_storage_t The underlying type used for storing
93  * categorical data (typically char)
94  * @tparam node_id_mapping_t If non-nullptr_t, this indicates the type we expect for
95  * node_id_mapping.
96  * @param node Pointer to the root node of this tree
97  * @param row Pointer to the input data for this row
98  * @param categorical_storage Pointer to where categorical split data is
99  * stored.
100  */
101 template <bool has_vector_leaves,
102  typename node_t,
103  typename io_t,
104  typename categorical_storage_t,
105  typename node_id_mapping_t = std::nullptr_t>
106 HOST DEVICE auto evaluate_tree_impl(node_t const* __restrict__ node,
107  io_t const* __restrict__ row,
108  categorical_storage_t const* __restrict__ categorical_storage,
109  node_t const* __restrict__ first_root_node = nullptr,
110  node_id_mapping_t node_id_mapping = nullptr)
111 {
112  using categorical_set_type = bitset<uint32_t, categorical_storage_t const>;
113  auto cur_node = *node;
114  do {
115  auto input_val = row[cur_node.feature_index()];
116  auto condition = cur_node.default_distant();
117  if (!isnan(input_val)) {
118  if (cur_node.is_categorical()) {
119  auto valid_categories =
120  categorical_set_type{categorical_storage + cur_node.index() + 1,
121  uint32_t(categorical_storage[cur_node.index()])};
122  condition = valid_categories.test(input_val);
123  } else {
124  condition = (input_val < cur_node.threshold());
125  }
126  }
127  node += cur_node.child_offset(condition);
128  cur_node = *node;
129  } while (!cur_node.is_leaf());
130  if constexpr (std::is_same_v<node_id_mapping_t, std::nullptr_t>) {
131  return cur_node.template output<has_vector_leaves>();
132  } else {
133  return node_id_mapping[node - first_root_node];
134  }
135 }
136 
155 template <bool has_vector_leaves,
156  bool has_categorical_nodes,
157  bool has_nonlocal_categories,
158  bool predict_leaf,
159  typename forest_t,
160  typename io_t,
161  typename categorical_data_t>
162 HOST DEVICE auto evaluate_tree(forest_t const& forest,
163  index_type tree_index,
164  io_t const* __restrict__ row,
165  categorical_data_t categorical_data)
166 {
167  using node_t = typename forest_t::node_type;
168  if constexpr (predict_leaf) {
169  auto leaf_node_id = index_type{};
170  if constexpr (has_nonlocal_categories) {
171  leaf_node_id = evaluate_tree_impl<has_vector_leaves>(forest.get_tree_root(tree_index),
172  row,
173  categorical_data,
176  } else {
177  leaf_node_id = evaluate_tree_impl<has_vector_leaves, has_categorical_nodes>(
178  forest.get_tree_root(tree_index),
179  row,
182  }
183  return leaf_node_id;
184  } else {
185  auto tree_output = std::conditional_t<has_vector_leaves,
186  typename node_t::index_type,
187  typename node_t::threshold_type>{};
188  if constexpr (has_nonlocal_categories) {
189  tree_output = evaluate_tree_impl<has_vector_leaves>(
190  forest.get_tree_root(tree_index), row, categorical_data);
191  } else {
192  tree_output = evaluate_tree_impl<has_vector_leaves, has_categorical_nodes>(
193  forest.get_tree_root(tree_index), row);
194  }
195  return tree_output;
196  }
197 }
198 
199 } // namespace detail
200 } // namespace fil
201 } // namespace ML
#define DEVICE
Definition: gpu_support.hpp:24
#define HOST
Definition: gpu_support.hpp:23
HOST DEVICE auto evaluate_tree_impl(node_t const *__restrict__ node, io_t const *__restrict__ row, node_t const *__restrict__ first_root_node=nullptr, node_id_mapping_t node_id_mapping=nullptr)
Definition: evaluate_tree.hpp:42
HOST DEVICE auto evaluate_tree(forest_t const &forest, index_type tree_index, io_t const *__restrict__ row, categorical_data_t categorical_data)
Definition: evaluate_tree.hpp:162
uint32_t index_type
Definition: index_type.hpp:9
Definition: dbscan.hpp:18
Definition: bitset.hpp:21
Definition: forest.hpp:24
HOST DEVICE auto * get_tree_root(index_type tree_index) const
Definition: forest.hpp:48
HOST DEVICE const auto * get_node_id_mapping() const
Definition: forest.hpp:55
Definition: node.hpp:81
HOST DEVICE auto const & index() const
Definition: node.hpp:177
HOST DEVICE constexpr auto child_offset(bool condition) const
Definition: node.hpp:159
HOST DEVICE constexpr auto default_distant() const
Definition: node.hpp:149