traversal_forest.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2024-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
19 
#include <cstddef>
#include <queue>
#include <stack>
#include <type_traits>
#include <utility>
#include <vector>
23 
24 namespace ML {
25 namespace forest {
26 
27 namespace detail {
30 template <forest_order order, typename T>
33  std::conditional_t<order == forest_order::depth_first, std::stack<T>, std::queue<T>>;
34  void add(T const& val) { data_.push(val); }
35  void add(T const& hot, T const& distant)
36  {
37  if constexpr (order == forest_order::depth_first) {
38  data_.push(distant);
39  data_.push(hot);
40  } else {
41  data_.push(hot);
42  data_.push(distant);
43  }
44  }
45  auto next()
46  {
47  if constexpr (std::is_same_v<backing_container_t, std::stack<T>>) {
48  auto result = data_.top();
49  data_.pop();
50  return result;
51  } else {
52  auto result = data_.front();
53  data_.pop();
54  return result;
55  }
56  }
57  auto peek()
58  {
59  if constexpr (std::is_same_v<backing_container_t, std::stack<T>>) {
60  return data_.top();
61  } else {
62  return data_.front();
63  }
64  }
65  [[nodiscard]] auto empty() { return data_.empty(); }
66  auto size() { return data_.size(); }
67 
68  private:
69  backing_container_t data_;
70 };
71 } // namespace detail
72  //
73 
74 template <typename node_t = traversal_node<std::size_t>, typename tree_id_t = std::size_t>
76  using node_type = node_t;
77  using node_id_type = typename node_type::id_type;
78  using tree_id_type = tree_id_t;
79  using node_uid_type = std::pair<tree_id_type, node_id_type>;
80  using index_type = std::size_t;
81 
82  virtual node_type get_node(tree_id_type tree_id, node_id_type node_id) const = 0;
83 
84  traversal_forest(std::vector<node_uid_type>&& root_node_uids) : root_node_uids_{root_node_uids} {}
85 
86  template <forest_order order, typename lambda_t>
87  void for_each(lambda_t&& lambda) const
88  {
89  auto to_be_visited = detail::traversal_container<
90  order,
91  std::conditional_t<order == forest_order::layered_children_segregated ||
93  // Layered traversals can track current depth without storing
94  // alongside each node. This can also be done with depth-first
95  // traversals, but we exchange memory footprint of the depth-first
96  // case for simplified code. By storing depth for both depth-first
97  // and breadth-first, we can make the code for each identical.
99  std::pair<node_uid_type, index_type>>>{};
100  auto parent_indices = detail::traversal_container<order, index_type>{};
101  auto cur_index = index_type{};
102  if constexpr (order == forest_order::depth_first || order == forest_order::breadth_first) {
103  for (auto const& root_node_uid : root_node_uids_) {
104  to_be_visited.add(std::make_pair(root_node_uid, std::size_t{}));
105  parent_indices.add(cur_index);
106  while (!to_be_visited.empty()) {
107  auto [node_uid, depth] = to_be_visited.next();
108  auto parent_index = parent_indices.next();
109  auto node = get_node(node_uid);
110  lambda(node_uid.first, node, depth, parent_index);
111  if (!node.is_leaf()) {
112  auto hot_uid = std::make_pair(std::make_pair(node_uid.first, node.hot_child()),
113  depth + index_type{1});
114  auto distant_uid = std::make_pair(std::make_pair(node_uid.first, node.distant_child()),
115  depth + index_type{1});
116  to_be_visited.add(hot_uid, distant_uid);
117  parent_indices.add(cur_index, cur_index);
118  }
119  ++cur_index;
120  }
121  }
122  } else if constexpr (order == forest_order::layered_children_segregated) {
123  for (auto const& root_node_uid : root_node_uids_) {
124  to_be_visited.add(root_node_uid);
125  parent_indices.add(cur_index++);
126  }
127  cur_index = index_type{};
128  auto depth = index_type{};
129  while (!to_be_visited.empty()) {
130  auto layer_node_uids = std::vector<node_uid_type>{};
131  auto layer_parent_indices = std::vector<index_type>{};
132  while (!to_be_visited.empty()) {
133  layer_node_uids.push_back(to_be_visited.next());
134  layer_parent_indices.push_back(parent_indices.next());
135  }
136  for (auto layer_index = index_type{}; layer_index < layer_node_uids.size(); ++layer_index) {
137  auto node_uid = layer_node_uids[layer_index];
138  auto parent_index = layer_parent_indices[layer_index];
139  auto node = get_node(node_uid);
140  lambda(node_uid.first, node, depth, parent_index);
141  if (!node.is_leaf()) {
142  auto hot_uid = std::make_pair(node_uid.first, node.hot_child());
143  to_be_visited.add(hot_uid);
144  parent_indices.add(cur_index);
145  }
146  ++cur_index;
147  }
148  // Reset cur_index before iterating through distant nodes
149  cur_index -= layer_node_uids.size();
150  for (auto layer_index = index_type{}; layer_index < layer_node_uids.size(); ++layer_index) {
151  auto node_uid = layer_node_uids[layer_index];
152  auto node = get_node(node_uid);
153  if (!node.is_leaf()) {
154  auto distant_uid = std::make_pair(node_uid.first, node.distant_child());
155  to_be_visited.add(distant_uid);
156  parent_indices.add(cur_index);
157  }
158  ++cur_index;
159  }
160  ++depth;
161  }
162  } else if constexpr (order == forest_order::layered_children_together) {
163  for (auto const& root_node_uid : root_node_uids_) {
164  to_be_visited.add(root_node_uid);
165  parent_indices.add(cur_index++);
166  }
167  cur_index = index_type{};
168  auto depth = index_type{};
169  while (!to_be_visited.empty()) {
170  auto layer_node_uids = std::vector<node_uid_type>{};
171  auto layer_parent_indices = std::vector<index_type>{};
172  while (!to_be_visited.empty()) {
173  layer_node_uids.push_back(to_be_visited.next());
174  layer_parent_indices.push_back(parent_indices.next());
175  }
176  for (auto layer_index = index_type{}; layer_index < layer_node_uids.size(); ++layer_index) {
177  auto node_uid = layer_node_uids[layer_index];
178  auto parent_index = layer_parent_indices[layer_index];
179  auto node = get_node(node_uid);
180  lambda(node_uid.first, node, depth, parent_index);
181  if (!node.is_leaf()) {
182  auto hot_uid = std::make_pair(node_uid.first, node.hot_child());
183  auto distant_uid = std::make_pair(node_uid.first, node.distant_child());
184  to_be_visited.add(hot_uid, distant_uid);
185  parent_indices.add(cur_index, cur_index);
186  }
187  ++cur_index;
188  }
189  ++depth;
190  }
191  }
192  }
193 
194  private:
195  auto get_node(node_uid_type node_uid) const { return get_node(node_uid.first, node_uid.second); }
196 
197  std::vector<node_uid_type> root_node_uids_{};
198 };
199 
200 } // namespace forest
201 } // namespace ML
Definition: dbscan.hpp:29
Definition: traversal_forest.hpp:31
auto peek()
Definition: traversal_forest.hpp:57
auto size()
Definition: traversal_forest.hpp:66
void add(T const &hot, T const &distant)
Definition: traversal_forest.hpp:35
void add(T const &val)
Definition: traversal_forest.hpp:34
auto empty()
Definition: traversal_forest.hpp:65
std::conditional_t< order==forest_order::depth_first, std::stack< T >, std::queue< T > > backing_container_t
Definition: traversal_forest.hpp:33
auto next()
Definition: traversal_forest.hpp:45
Definition: traversal_forest.hpp:75
tree_id_t tree_id_type
Definition: traversal_forest.hpp:78
typename node_type::id_type node_id_type
Definition: traversal_forest.hpp:77
std::pair< tree_id_type, node_id_type > node_uid_type
Definition: traversal_forest.hpp:79
std::size_t index_type
Definition: traversal_forest.hpp:80
traversal_forest(std::vector< node_uid_type > &&root_node_uids)
Definition: traversal_forest.hpp:84
void for_each(lambda_t &&lambda) const
Definition: traversal_forest.hpp:87
node_t node_type
Definition: traversal_forest.hpp:76
virtual node_type get_node(tree_id_type tree_id, node_id_type node_id) const =0