traversal_forest.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
8 
#include <cstddef>
#include <queue>
#include <stack>
#include <type_traits>
#include <utility>
#include <vector>
12 
13 namespace ML {
14 namespace forest {
15 
16 namespace detail {
19 template <forest_order order, typename T>
22  std::conditional_t<order == forest_order::depth_first, std::stack<T>, std::queue<T>>;
23  void add(T const& val) { data_.push(val); }
24  void add(T const& hot, T const& distant)
25  {
26  if constexpr (order == forest_order::depth_first) {
27  data_.push(distant);
28  data_.push(hot);
29  } else {
30  data_.push(hot);
31  data_.push(distant);
32  }
33  }
34  auto next()
35  {
36  if constexpr (std::is_same_v<backing_container_t, std::stack<T>>) {
37  auto result = data_.top();
38  data_.pop();
39  return result;
40  } else {
41  auto result = data_.front();
42  data_.pop();
43  return result;
44  }
45  }
46  auto peek()
47  {
48  if constexpr (std::is_same_v<backing_container_t, std::stack<T>>) {
49  return data_.top();
50  } else {
51  return data_.front();
52  }
53  }
54  [[nodiscard]] auto empty() { return data_.empty(); }
55  auto size() { return data_.size(); }
56 
57  private:
58  backing_container_t data_;
59 };
60 } // namespace detail
61  //
62 
63 template <typename node_t = traversal_node<std::size_t>, typename tree_id_t = std::size_t>
65  using node_type = node_t;
66  using node_id_type = typename node_type::id_type;
67  using tree_id_type = tree_id_t;
68  using node_uid_type = std::pair<tree_id_type, node_id_type>;
69  using index_type = std::size_t;
70 
71  virtual node_type get_node(tree_id_type tree_id, node_id_type node_id) const = 0;
72 
73  traversal_forest(std::vector<node_uid_type>&& root_node_uids) : root_node_uids_{root_node_uids} {}
74 
75  template <forest_order order, typename lambda_t>
76  void for_each(lambda_t&& lambda) const
77  {
78  auto to_be_visited = detail::traversal_container<
79  order,
80  std::conditional_t<order == forest_order::layered_children_segregated ||
82  // Layered traversals can track current depth without storing
83  // alongside each node. This can also be done with depth-first
84  // traversals, but we exchange memory footprint of the depth-first
85  // case for simplified code. By storing depth for both depth-first
86  // and breadth-first, we can make the code for each identical.
88  std::pair<node_uid_type, index_type>>>{};
89  auto parent_indices = detail::traversal_container<order, index_type>{};
90  auto cur_index = index_type{};
91  if constexpr (order == forest_order::depth_first || order == forest_order::breadth_first) {
92  for (auto const& root_node_uid : root_node_uids_) {
93  to_be_visited.add(std::make_pair(root_node_uid, std::size_t{}));
94  parent_indices.add(cur_index);
95  while (!to_be_visited.empty()) {
96  auto [node_uid, depth] = to_be_visited.next();
97  auto parent_index = parent_indices.next();
98  auto node = get_node(node_uid);
99  lambda(node_uid.first, node, depth, parent_index);
100  if (!node.is_leaf()) {
101  auto hot_uid = std::make_pair(std::make_pair(node_uid.first, node.hot_child()),
102  depth + index_type{1});
103  auto distant_uid = std::make_pair(std::make_pair(node_uid.first, node.distant_child()),
104  depth + index_type{1});
105  to_be_visited.add(hot_uid, distant_uid);
106  parent_indices.add(cur_index, cur_index);
107  }
108  ++cur_index;
109  }
110  }
111  } else if constexpr (order == forest_order::layered_children_segregated) {
112  for (auto const& root_node_uid : root_node_uids_) {
113  to_be_visited.add(root_node_uid);
114  parent_indices.add(cur_index++);
115  }
116  cur_index = index_type{};
117  auto depth = index_type{};
118  while (!to_be_visited.empty()) {
119  auto layer_node_uids = std::vector<node_uid_type>{};
120  auto layer_parent_indices = std::vector<index_type>{};
121  while (!to_be_visited.empty()) {
122  layer_node_uids.push_back(to_be_visited.next());
123  layer_parent_indices.push_back(parent_indices.next());
124  }
125  for (auto layer_index = index_type{}; layer_index < layer_node_uids.size(); ++layer_index) {
126  auto node_uid = layer_node_uids[layer_index];
127  auto parent_index = layer_parent_indices[layer_index];
128  auto node = get_node(node_uid);
129  lambda(node_uid.first, node, depth, parent_index);
130  if (!node.is_leaf()) {
131  auto hot_uid = std::make_pair(node_uid.first, node.hot_child());
132  to_be_visited.add(hot_uid);
133  parent_indices.add(cur_index);
134  }
135  ++cur_index;
136  }
137  // Reset cur_index before iterating through distant nodes
138  cur_index -= layer_node_uids.size();
139  for (auto layer_index = index_type{}; layer_index < layer_node_uids.size(); ++layer_index) {
140  auto node_uid = layer_node_uids[layer_index];
141  auto node = get_node(node_uid);
142  if (!node.is_leaf()) {
143  auto distant_uid = std::make_pair(node_uid.first, node.distant_child());
144  to_be_visited.add(distant_uid);
145  parent_indices.add(cur_index);
146  }
147  ++cur_index;
148  }
149  ++depth;
150  }
151  } else if constexpr (order == forest_order::layered_children_together) {
152  for (auto const& root_node_uid : root_node_uids_) {
153  to_be_visited.add(root_node_uid);
154  parent_indices.add(cur_index++);
155  }
156  cur_index = index_type{};
157  auto depth = index_type{};
158  while (!to_be_visited.empty()) {
159  auto layer_node_uids = std::vector<node_uid_type>{};
160  auto layer_parent_indices = std::vector<index_type>{};
161  while (!to_be_visited.empty()) {
162  layer_node_uids.push_back(to_be_visited.next());
163  layer_parent_indices.push_back(parent_indices.next());
164  }
165  for (auto layer_index = index_type{}; layer_index < layer_node_uids.size(); ++layer_index) {
166  auto node_uid = layer_node_uids[layer_index];
167  auto parent_index = layer_parent_indices[layer_index];
168  auto node = get_node(node_uid);
169  lambda(node_uid.first, node, depth, parent_index);
170  if (!node.is_leaf()) {
171  auto hot_uid = std::make_pair(node_uid.first, node.hot_child());
172  auto distant_uid = std::make_pair(node_uid.first, node.distant_child());
173  to_be_visited.add(hot_uid, distant_uid);
174  parent_indices.add(cur_index, cur_index);
175  }
176  ++cur_index;
177  }
178  ++depth;
179  }
180  }
181  }
182 
183  private:
184  auto get_node(node_uid_type node_uid) const { return get_node(node_uid.first, node_uid.second); }
185 
186  std::vector<node_uid_type> root_node_uids_{};
187 };
188 
189 } // namespace forest
190 } // namespace ML
Definition: dbscan.hpp:18
Definition: traversal_forest.hpp:20
auto peek()
Definition: traversal_forest.hpp:46
auto size()
Definition: traversal_forest.hpp:55
void add(T const &hot, T const &distant)
Definition: traversal_forest.hpp:24
void add(T const &val)
Definition: traversal_forest.hpp:23
auto empty()
Definition: traversal_forest.hpp:54
std::conditional_t< order==forest_order::depth_first, std::stack< T >, std::queue< T > > backing_container_t
Definition: traversal_forest.hpp:22
auto next()
Definition: traversal_forest.hpp:34
Definition: traversal_forest.hpp:64
tree_id_t tree_id_type
Definition: traversal_forest.hpp:67
typename node_type::id_type node_id_type
Definition: traversal_forest.hpp:66
std::pair< tree_id_type, node_id_type > node_uid_type
Definition: traversal_forest.hpp:68
std::size_t index_type
Definition: traversal_forest.hpp:69
traversal_forest(std::vector< node_uid_type > &&root_node_uids)
Definition: traversal_forest.hpp:73
void for_each(lambda_t &&lambda) const
Definition: traversal_forest.hpp:76
node_t node_type
Definition: traversal_forest.hpp:65
virtual node_type get_node(tree_id_type tree_id, node_id_type node_id) const =0