evaluate_tree.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 #include <stdint.h>
18 
19 #include <type_traits>
20 #ifndef __CUDACC__
21 #include <math.h>
22 #endif
25 namespace ML {
26 namespace experimental {
27 namespace fil {
28 namespace detail {
29 
30 /*
31  * Evaluate a single tree on a single row.
32  * If node_id_mapping is not-nullptr, this kernel outputs leaf node's ID
33  * instead of the leaf value.
34  *
35  * @tparam has_vector_leaves Whether or not this tree has vector leaves
36  * @tparam has_categorical_nodes Whether or not this tree has any nodes with
37  * categorical splits
38  * @tparam node_t The type of nodes in this tree
39  * @tparam io_t The type used for input to and output from this tree (typically
40  * either floats or doubles)
41  * @tparam node_id_mapping_t If non-nullptr_t, this indicates the type we expect for
42  * node_id_mapping.
43  * @param node Pointer to the root node of this tree
44  * @param row Pointer to the input data for this row
45  * @param first_root_node Pointer to the root node of the first tree.
46  * @param node_id_mapping Array representing the mapping from internal node IDs to
47  * final leaf ID outputs
48  */
49 template <bool has_vector_leaves,
50  bool has_categorical_nodes,
51  typename node_t,
52  typename io_t,
53  typename node_id_mapping_t = std::nullptr_t>
54 HOST DEVICE auto evaluate_tree_impl(node_t const* __restrict__ node,
55  io_t const* __restrict__ row,
56  node_t const* __restrict__ first_root_node = nullptr,
57  node_id_mapping_t node_id_mapping = nullptr)
58 {
59  using categorical_set_type = bitset<uint32_t, typename node_t::index_type const>;
60  auto cur_node = *node;
61  do {
62  auto input_val = row[cur_node.feature_index()];
63  auto condition = true;
64  if constexpr (has_categorical_nodes) {
65  if (cur_node.is_categorical()) {
66  auto valid_categories = categorical_set_type{
67  &cur_node.index(), uint32_t(sizeof(typename node_t::index_type) * 8)};
68  condition = valid_categories.test(input_val) && !isnan(input_val);
69  } else {
70  condition = (input_val < cur_node.threshold());
71  }
72  } else {
73  condition = (input_val < cur_node.threshold());
74  }
75  if (!condition && cur_node.default_distant()) { condition = isnan(input_val); }
76  node += cur_node.child_offset(condition);
77  cur_node = *node;
78  } while (!cur_node.is_leaf());
79  if constexpr (std::is_same_v<node_id_mapping_t, std::nullptr_t>) {
80  return cur_node.template output<has_vector_leaves>();
81  } else {
82  return node_id_mapping[node - first_root_node];
83  }
84 }
85 
86 /*
87  * Evaluate a single tree which requires external categorical storage on a
88  * single node.
89  * If node_id_mapping is not-nullptr, this kernel outputs leaf node's ID
90  * instead of the leaf value.
91  *
92  * For non-categorical models and models with a relatively small number of
93  * categories for any feature, all information necessary for model evaluation
94  * can be stored on a single node. If the number of categories for any
95  * feature exceeds the available space on a node, however, the
96  * categorical split data must be stored external to the node. We pass a
97  * pointer to this external data and reconstruct bitsets from it indicating
98  * the positive and negative categories for each categorical node.
99  *
100  * @tparam has_vector_leaves Whether or not this tree has vector leaves
101  * @tparam node_t The type of nodes in this tree
102  * @tparam io_t The type used for input to and output from this tree (typically
103  * either floats or doubles)
104  * @tparam categorical_storage_t The underlying type used for storing
105  * categorical data (typically char)
106  * @tparam node_id_mapping_t If non-nullptr_t, this indicates the type we expect for
107  * node_id_mapping.
108  * @param node Pointer to the root node of this tree
109  * @param row Pointer to the input data for this row
110  * @param categorical_storage Pointer to where categorical split data is
111  * stored.
112  */
113 template <bool has_vector_leaves,
114  typename node_t,
115  typename io_t,
116  typename categorical_storage_t,
117  typename node_id_mapping_t = std::nullptr_t>
118 HOST DEVICE auto evaluate_tree_impl(node_t const* __restrict__ node,
119  io_t const* __restrict__ row,
120  categorical_storage_t const* __restrict__ categorical_storage,
121  node_t const* __restrict__ first_root_node = nullptr,
122  node_id_mapping_t node_id_mapping = nullptr)
123 {
124  using categorical_set_type = bitset<uint32_t, categorical_storage_t const>;
125  auto cur_node = *node;
126  do {
127  auto input_val = row[cur_node.feature_index()];
128  auto condition = cur_node.default_distant();
129  if (!isnan(input_val)) {
130  if (cur_node.is_categorical()) {
131  auto valid_categories =
132  categorical_set_type{categorical_storage + cur_node.index() + 1,
133  uint32_t(categorical_storage[cur_node.index()])};
134  condition = valid_categories.test(input_val);
135  } else {
136  condition = (input_val < cur_node.threshold());
137  }
138  }
139  node += cur_node.child_offset(condition);
140  cur_node = *node;
141  } while (!cur_node.is_leaf());
142  if constexpr (std::is_same_v<node_id_mapping_t, std::nullptr_t>) {
143  return cur_node.template output<has_vector_leaves>();
144  } else {
145  return node_id_mapping[node - first_root_node];
146  }
147 }
148 
167 template <bool has_vector_leaves,
168  bool has_categorical_nodes,
169  bool has_nonlocal_categories,
170  bool predict_leaf,
171  typename forest_t,
172  typename io_t,
173  typename categorical_data_t>
175  index_type tree_index,
176  io_t const* __restrict__ row,
177  categorical_data_t categorical_data)
178 {
179  using node_t = typename forest_t::node_type;
180  if constexpr (predict_leaf) {
181  auto leaf_node_id = index_type{};
182  if constexpr (has_nonlocal_categories) {
183  leaf_node_id = evaluate_tree_impl<has_vector_leaves>(forest.get_tree_root(tree_index),
184  row,
185  categorical_data,
188  } else {
189  leaf_node_id = evaluate_tree_impl<has_vector_leaves, has_categorical_nodes>(
190  forest.get_tree_root(tree_index),
191  row,
194  }
195  return leaf_node_id;
196  } else {
197  auto tree_output = std::conditional_t<has_vector_leaves,
198  typename node_t::index_type,
199  typename node_t::threshold_type>{};
200  if constexpr (has_nonlocal_categories) {
201  tree_output = evaluate_tree_impl<has_vector_leaves>(
202  forest.get_tree_root(tree_index), row, categorical_data);
203  } else {
204  tree_output = evaluate_tree_impl<has_vector_leaves, has_categorical_nodes>(
205  forest.get_tree_root(tree_index), row);
206  }
207  return tree_output;
208  }
209 }
210 
211 } // namespace detail
212 } // namespace fil
213 } // namespace experimental
214 } // namespace ML
#define DEVICE
Definition: gpu_support.hpp:35
#define HOST
Definition: gpu_support.hpp:34
HOST DEVICE auto evaluate_tree(forest_t const &forest, index_type tree_index, io_t const *__restrict__ row, categorical_data_t categorical_data)
Definition: evaluate_tree.hpp:174
HOST DEVICE auto evaluate_tree_impl(node_t const *__restrict__ node, io_t const *__restrict__ row, node_t const *__restrict__ first_root_node=nullptr, node_id_mapping_t node_id_mapping=nullptr)
Definition: evaluate_tree.hpp:54
uint32_t index_type
Definition: index_type.hpp:21
forest< real_t > * forest_t
Definition: fil.h:89
Definition: dbscan.hpp:30
Definition: bitset.hpp:33
Definition: forest.hpp:36
HOST DEVICE auto * get_tree_root(index_type tree_index) const
Definition: forest.hpp:58
HOST DEVICE const auto * get_node_id_mapping() const
Definition: forest.hpp:65
Definition: node.hpp:93
HOST DEVICE constexpr auto default_distant() const
Definition: node.hpp:161
HOST DEVICE constexpr auto child_offset(bool condition) const
Definition: node.hpp:171
HOST DEVICE auto const & index() const
Definition: node.hpp:188