node.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
10 
11 #include <iostream>
12 #include <type_traits>
13 
14 namespace ML {
15 namespace fil {
16 
17 namespace detail {
18 
19 /*
20  * Return the byte size to which a node with the given types should be aligned
21  */
22 template <typename threshold_t, typename index_t, typename metadata_storage_t, typename offset_t>
23 auto constexpr get_node_alignment()
24 {
25  auto total = index_type(std::max(sizeof(threshold_t), sizeof(index_t)) +
26  sizeof(metadata_storage_t) + sizeof(offset_t));
27  auto result = index_type{8};
28  if (total > result) { result = index_type{16}; }
29  if (total > result) { result = index_type{32}; }
30  if (total > result) { result = index_type{64}; }
31  if (total > result) { result = index_type{128}; }
32  if (total > result) { result = index_type{256}; }
33  if (total > result) { result = total; }
34  return result;
35 }
36 
37 } // namespace detail
38 
39 /* @brief A single node in a forest model
40  *
41  * Note that this implementation includes NO error checking for poorly-chosen
42  * template types. If the types are not large enough to hold the required data,
43  * an incorrect node will be silently constructed. Error checking occurs
44  * instead at the time of construction of the entire forest.
45  *
46  * @tparam layout_v The layout for nodes within the forest
47  *
48  * @tparam threshold_t The type used as a threshold for evaluating a non-leaf
49  * node or (when possible) the output of a leaf node. For non-categorical
50  * nodes, if an input value is less than this threshold, the node evaluates to
51  * true. For leaf nodes, output values will only be stored as this type if it
52  * matches the leaf output type expected by the forest. Typically, this type is
53  * either float or double.
54  *
55  * @tparam index_t The type used as an index to the output data for leaf nodes,
56  * or to the categorical set for a categorical non-leaf node. This type should
57  * be large enough to index the entire array of output data or categorical sets
58  * stored in the forest. Typically, this type is either uint32_t or uint64_t.
59  * Smaller types offer no benefit, since this value is stored in a union with
60  * threshold_t, which is at least 32 bits.
61  *
62  * @tparam metadata_storage_t An unsigned integral type used for a bit-wise
63  * representation of metadata about this node. The first three bits encode
64  * whether or not this is a leaf node, whether or not we should default to the
65  * more distant child in case of missing values, and whether or not this node
66  * is categorical. The remaining bits are used to encode the feature index for
67  * this node. Thus, uint8_t may be used for 2**(8 - 3) = 32 or fewer features,
68  * uint16_t for 2**(16 - 3) = 8192 or fewer, and uint32_t for 536870912 or
69  * fewer features.
70  *
71  * @tparam offset_t An integral type used to indicate the offset from
72  * this node to its most distant child. This type must be large enough to store
73  * the largest such offset in the forest model.
74  */
75 template <tree_layout layout_v,
76  typename threshold_t,
77  typename index_t,
78  typename metadata_storage_t,
79  typename offset_t>
80 struct alignas(detail::get_node_alignment<threshold_t, index_t, metadata_storage_t, offset_t>())
81  node {
82  // @brief An alias for layout_v
83  auto constexpr static const layout = layout_v;
84  // @brief An alias for threshold_t
85  using threshold_type = threshold_t;
86  // @brief An alias for index_t
87  using index_type = index_t;
88  /* @brief A union to hold either a threshold value or index
89  *
90  * All nodes will need EITHER a threshold value, an output value, OR an index
91  * to data elsewhere that will be used either for evaluating the node (e.g. an
92  * index to a categorical set) or creating an output (e.g. an index to vector
93  * leaf output). This union allows us to store either of these values without
94  * using additional space for the unused value.
95  */
96  union value_type {
97  threshold_t value;
98  index_t index;
99  };
  // @brief An alias for metadata_storage_t
101  using metadata_storage_type = metadata_storage_t;
  // @brief An alias for offset_t
103  using offset_type = offset_t;
104 
105  // TODO(wphicks): Add custom type to ensure given child offset is at least
106  // one
  // The two constructors below differ only in which union member they set.
  // The first stores `value` (a threshold or scalar leaf output) and the second
  // stores `index`. Both pack the three flags and `feature` into the metadata
  // word via construct_metadata() and store `distant_child_offset` unchanged.
  // Note that is_leaf() derives leafness from a zero distant offset, not from
  // the `is_leaf_node` flag.
  // -Wnarrowing is silenced around this aggregate initialization, presumably
  // because the arguments can narrow when converted to the member types.
  // NOTE(review): confirm.
107 #pragma GCC diagnostic push
108 #pragma GCC diagnostic ignored "-Wnarrowing"
110  bool is_leaf_node = true,
111  bool default_to_distant_child = false,
112  bool is_categorical_node = false,
114  offset_type distant_child_offset = offset_type{})
115  : aligned_data{
116  .inner_data = {
117  {.value = value},
118  distant_child_offset,
119  construct_metadata(is_leaf_node, default_to_distant_child, is_categorical_node, feature)}}
120  {
121  }
122 
124  bool is_leaf_node = true,
125  bool default_to_distant_child = false,
126  bool is_categorical_node = false,
128  offset_type distant_child_offset = offset_type{})
129  : aligned_data{
130  .inner_data = {
131  {.index = index},
132  distant_child_offset,
133  construct_metadata(is_leaf_node, default_to_distant_child, is_categorical_node, feature)}}
134  {
135  }
136 #pragma GCC diagnostic pop
137 
138  /* The index of the feature for this node */
139  HOST DEVICE auto constexpr feature_index() const
140  {
141  return aligned_data.inner_data.metadata & FEATURE_MASK;
142  }
143  /* Whether or not this node is a leaf node: true iff the distant-child
  * offset is zero (the LEAF bit packed into the metadata is not read here) */
144  HOST DEVICE auto constexpr is_leaf() const
145  {
146  return !bool(aligned_data.inner_data.distant_offset);
147  }
148  /* Whether or not to default to distant child in case of missing values */
149  HOST DEVICE auto constexpr default_distant() const
150  {
151  return bool(aligned_data.inner_data.metadata & DEFAULT_DISTANT_MASK);
152  }
153  /* Whether or not this node is a categorical node */
154  HOST DEVICE auto constexpr is_categorical() const
155  {
156  return bool(aligned_data.inner_data.metadata & CATEGORICAL_MASK);
157  }
158  /* The offset to the child of this node if it evaluates to the given condition
  *
  * depth_first: 1 (the adjacent child) when condition is false, otherwise
  * distant_offset.
  * breadth_first (and the other layout tested alongside it): distant_offset - 1
  * when condition is false, otherwise distant_offset.
  * Any other layout reaches the else branch, where the static_assert is
  * necessarily false, so it is rejected at compile time. */
159  HOST DEVICE auto constexpr child_offset(bool condition) const
160  {
161  if constexpr (layout == tree_layout::depth_first) {
162  return offset_type{1} + condition * (aligned_data.inner_data.distant_offset - offset_type{1});
163  } else if constexpr (layout == tree_layout::breadth_first ||
165  return condition * offset_type{1} + (aligned_data.inner_data.distant_offset - offset_type{1});
166  } else {
167  static_assert(layout == tree_layout::depth_first);
168  }
169  }
170  /* The threshold value for this node */
171  HOST DEVICE auto constexpr threshold() const
172  {
173  return aligned_data.inner_data.stored_value.value;
174  }
175 
176  /* The index value for this node */
177  HOST DEVICE auto const& index() const { return aligned_data.inner_data.stored_value.index; }
178  /* The output value for this node
179  *
180  * @tparam has_vector_leaves If true, return the stored index (e.g. into
  * vector leaf output); otherwise return the stored scalar value.
181  */
182  template <bool has_vector_leaves>
183  HOST DEVICE auto constexpr output() const
184  {
185  if constexpr (has_vector_leaves) {
186  return aligned_data.inner_data.stored_value.index;
187  } else {
188  return aligned_data.inner_data.stored_value.value;
189  }
190  }
191 
192  private:
193  /* Define all bit masks required to extract information from the stored
194  * metadata. The first bit tells us whether or not this is a leaf node, the
195  * second tells us whether or not we should default to the distant child in
196  * the case of a missing value, and the third tells us whether or not this is
197  * a categorical node. The remaining bits indicate the index of the feature
198  * for this node */
  // NOTE(review): the masks are built with `1 << BIT`, which is an int shift.
  // This assumes LEAF_BIT is the top bit of metadata_storage_t; its
  // initializer is not shown in this listing.
  // - 8/16-bit storage: no overflow.
  // - 32-bit storage: relies on the pre-C++20 implementation-defined
  //   conversion at bit 31.
  // - 64-bit storage: undefined behavior.
  // FEATURE_MASK covers every bit except the three flag bits.
199  auto constexpr static const LEAF_BIT =
201  auto constexpr static const LEAF_MASK = metadata_storage_type(1 << LEAF_BIT);
202  auto constexpr static const DEFAULT_DISTANT_BIT = metadata_storage_type(LEAF_BIT - 1);
203  auto constexpr static const DEFAULT_DISTANT_MASK =
204  metadata_storage_type(1 << DEFAULT_DISTANT_BIT);
205  auto constexpr static const CATEGORICAL_BIT = metadata_storage_type(DEFAULT_DISTANT_BIT - 1);
206  auto constexpr static const CATEGORICAL_MASK = metadata_storage_type(1 << CATEGORICAL_BIT);
207  auto constexpr static const FEATURE_MASK =
208  metadata_storage_type(~(LEAF_MASK | DEFAULT_DISTANT_MASK | CATEGORICAL_MASK));
209 
210  // Helper function for bit packing with the above masks
  // Each flag is shifted into its own bit, and `feature` is masked with
  // FEATURE_MASK. Any feature bits that collide with the flag bits are
  // therefore silently dropped. `+` acts as bitwise OR here because the four
  // terms occupy disjoint bits.
211  auto static constexpr construct_metadata(bool is_leaf_node = true,
212  bool default_to_distant_child = false,
213  bool is_categorical_node = false,
215  {
216  return metadata_storage_type(
217  (is_leaf_node << LEAF_BIT) + (default_to_distant_child << DEFAULT_DISTANT_BIT) +
218  (is_categorical_node << CATEGORICAL_BIT) + (feature & FEATURE_MASK));
219  }
220 
  // Node alignment in bytes (the same value passed to alignas above). It is
  // also the length of spacer_data below.
221  auto static constexpr const byte_size =
222  detail::get_node_alignment<threshold_t, index_t, metadata_storage_t, offset_t>();
223 
224  struct inner_data_type {
225  value_type stored_value;
226  // TODO (wphicks): It may be possible to store both of the following together
227  // to save bytes
228  offset_type distant_offset;
229  metadata_storage_type metadata;
230  };
  // Overlaying inner_data with a byte_size char array makes this union at
  // least byte_size bytes.
231  union aligned_data_type {
232  inner_data_type inner_data;
233  char spacer_data[byte_size];
234  };
235 
236  aligned_data_type aligned_data;
237 };
238 
239 } // namespace fil
240 } // namespace ML
#define DEVICE
Definition: gpu_support.hpp:24
#define HOST
Definition: gpu_support.hpp:23
math_t max(math_t a, math_t b)
Definition: learning_rate.h:16
constexpr auto get_node_alignment()
Definition: node.hpp:23
tree_layout
Definition: tree_layout.hpp:8
uint32_t index_type
Definition: index_type.hpp:9
Definition: dbscan.hpp:18
Definition: node.hpp:81
HOST DEVICE constexpr auto output() const
Definition: node.hpp:183
constexpr static auto const layout
Definition: node.hpp:83
threshold_t threshold_type
Definition: node.hpp:85
HOST DEVICE constexpr auto is_categorical() const
Definition: node.hpp:154
HOST constexpr DEVICE node(index_type index, bool is_leaf_node=true, bool default_to_distant_child=false, bool is_categorical_node=false, metadata_storage_type feature=metadata_storage_type{}, offset_type distant_child_offset=offset_type{})
Definition: node.hpp:123
offset_t offset_type
An alias for offset_t.
Definition: node.hpp:103
HOST constexpr DEVICE node(threshold_type value=threshold_type{}, bool is_leaf_node=true, bool default_to_distant_child=false, bool is_categorical_node=false, metadata_storage_type feature=metadata_storage_type{}, offset_type distant_child_offset=offset_type{})
Definition: node.hpp:109
HOST DEVICE auto const & index() const
Definition: node.hpp:177
metadata_storage_t metadata_storage_type
An alias for metadata_storage_t.
Definition: node.hpp:101
index_t index_type
Definition: node.hpp:87
HOST DEVICE constexpr auto feature_index() const
Definition: node.hpp:139
HOST DEVICE constexpr auto is_leaf() const
Definition: node.hpp:144
HOST DEVICE constexpr auto child_offset(bool condition) const
Definition: node.hpp:159
HOST DEVICE constexpr auto threshold() const
Definition: node.hpp:171
HOST DEVICE constexpr auto default_distant() const
Definition: node.hpp:149
Definition: node.hpp:96
index_t index
Definition: node.hpp:98
threshold_t value
Definition: node.hpp:97