Loading [MathJax]/extensions/tex2jax.js
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
traversal_forest.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2024-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
19 
#include <cstddef>
#include <queue>
#include <stack>
#include <utility>
23 
24 namespace ML {
25 namespace experimental {
26 namespace forest {
27 
28 namespace detail {
31 template <forest_order order, typename T>
34  std::conditional_t<order == forest_order::depth_first, std::stack<T>, std::queue<T>>;
35  void add(T const& val) { data_.push(val); }
36  void add(T const& hot, T const& distant)
37  {
38  if constexpr (order == forest_order::depth_first) {
39  data_.push(distant);
40  data_.push(hot);
41  } else {
42  data_.push(hot);
43  data_.push(distant);
44  }
45  }
46  auto next()
47  {
48  if constexpr (std::is_same_v<backing_container_t, std::stack<T>>) {
49  auto result = data_.top();
50  data_.pop();
51  return result;
52  } else {
53  auto result = data_.front();
54  data_.pop();
55  return result;
56  }
57  }
58  auto peek()
59  {
60  if constexpr (std::is_same_v<backing_container_t, std::stack<T>>) {
61  return data_.top();
62  } else {
63  return data_.front();
64  }
65  }
66  [[nodiscard]] auto empty() { return data_.empty(); }
67  auto size() { return data_.size(); }
68 
69  private:
70  backing_container_t data_;
71 };
72 } // namespace detail
73  //
74 
75 template <typename node_t = traversal_node<std::size_t>, typename tree_id_t = std::size_t>
77  using node_type = node_t;
78  using node_id_type = typename node_type::id_type;
79  using tree_id_type = tree_id_t;
80  using node_uid_type = std::pair<tree_id_type, node_id_type>;
81  using index_type = std::size_t;
82 
83  virtual node_type get_node(tree_id_type tree_id, node_id_type node_id) const = 0;
84 
85  traversal_forest(std::vector<node_uid_type>&& root_node_uids) : root_node_uids_{root_node_uids} {}
86 
87  template <forest_order order, typename lambda_t>
88  void for_each(lambda_t&& lambda) const
89  {
90  auto to_be_visited = detail::traversal_container<
91  order,
92  std::conditional_t<order == forest_order::layered_children_segregated ||
94  // Layered traversals can track current depth without storing
95  // alongside each node. This can also be done with depth-first
96  // traversals, but we exchange memory footprint of the depth-first
97  // case for simplified code. By storing depth for both depth-first
98  // and breadth-first, we can make the code for each identical.
100  std::pair<node_uid_type, index_type>>>{};
101  auto parent_indices = detail::traversal_container<order, index_type>{};
102  auto cur_index = index_type{};
103  if constexpr (order == forest_order::depth_first || order == forest_order::breadth_first) {
104  for (auto const& root_node_uid : root_node_uids_) {
105  to_be_visited.add(std::make_pair(root_node_uid, std::size_t{}));
106  parent_indices.add(cur_index);
107  while (!to_be_visited.empty()) {
108  auto [node_uid, depth] = to_be_visited.next();
109  auto parent_index = parent_indices.next();
110  auto node = get_node(node_uid);
111  lambda(node_uid.first, node, depth, parent_index);
112  if (!node.is_leaf()) {
113  auto hot_uid = std::make_pair(std::make_pair(node_uid.first, node.hot_child()),
114  depth + index_type{1});
115  auto distant_uid = std::make_pair(std::make_pair(node_uid.first, node.distant_child()),
116  depth + index_type{1});
117  to_be_visited.add(hot_uid, distant_uid);
118  parent_indices.add(cur_index, cur_index);
119  }
120  ++cur_index;
121  }
122  }
123  } else if constexpr (order == forest_order::layered_children_segregated) {
124  for (auto const& root_node_uid : root_node_uids_) {
125  to_be_visited.add(root_node_uid);
126  parent_indices.add(cur_index++);
127  }
128  cur_index = index_type{};
129  auto depth = index_type{};
130  while (!to_be_visited.empty()) {
131  auto layer_node_uids = std::vector<node_uid_type>{};
132  auto layer_parent_indices = std::vector<index_type>{};
133  while (!to_be_visited.empty()) {
134  layer_node_uids.push_back(to_be_visited.next());
135  layer_parent_indices.push_back(parent_indices.next());
136  }
137  for (auto layer_index = index_type{}; layer_index < layer_node_uids.size(); ++layer_index) {
138  auto node_uid = layer_node_uids[layer_index];
139  auto parent_index = layer_parent_indices[layer_index];
140  auto node = get_node(node_uid);
141  lambda(node_uid.first, node, depth, parent_index);
142  if (!node.is_leaf()) {
143  auto hot_uid = std::make_pair(node_uid.first, node.hot_child());
144  to_be_visited.add(hot_uid);
145  parent_indices.add(cur_index);
146  }
147  ++cur_index;
148  }
149  // Reset cur_index before iterating through distant nodes
150  cur_index -= layer_node_uids.size();
151  for (auto layer_index = index_type{}; layer_index < layer_node_uids.size(); ++layer_index) {
152  auto node_uid = layer_node_uids[layer_index];
153  auto node = get_node(node_uid);
154  if (!node.is_leaf()) {
155  auto distant_uid = std::make_pair(node_uid.first, node.distant_child());
156  to_be_visited.add(distant_uid);
157  parent_indices.add(cur_index);
158  }
159  ++cur_index;
160  }
161  ++depth;
162  }
163  } else if constexpr (order == forest_order::layered_children_together) {
164  for (auto const& root_node_uid : root_node_uids_) {
165  to_be_visited.add(root_node_uid);
166  parent_indices.add(cur_index++);
167  }
168  cur_index = index_type{};
169  auto depth = index_type{};
170  while (!to_be_visited.empty()) {
171  auto layer_node_uids = std::vector<node_uid_type>{};
172  auto layer_parent_indices = std::vector<index_type>{};
173  while (!to_be_visited.empty()) {
174  layer_node_uids.push_back(to_be_visited.next());
175  layer_parent_indices.push_back(parent_indices.next());
176  }
177  for (auto layer_index = index_type{}; layer_index < layer_node_uids.size(); ++layer_index) {
178  auto node_uid = layer_node_uids[layer_index];
179  auto parent_index = layer_parent_indices[layer_index];
180  auto node = get_node(node_uid);
181  lambda(node_uid.first, node, depth, parent_index);
182  if (!node.is_leaf()) {
183  auto hot_uid = std::make_pair(node_uid.first, node.hot_child());
184  auto distant_uid = std::make_pair(node_uid.first, node.distant_child());
185  to_be_visited.add(hot_uid, distant_uid);
186  parent_indices.add(cur_index, cur_index);
187  }
188  ++cur_index;
189  }
190  ++depth;
191  }
192  }
193  }
194 
195  private:
196  auto get_node(node_uid_type node_uid) const { return get_node(node_uid.first, node_uid.second); }
197 
198  std::vector<node_uid_type> root_node_uids_{};
199 };
200 
201 } // namespace forest
202 } // namespace experimental
203 } // namespace ML
Definition: dbscan.hpp:30
auto size()
Definition: traversal_forest.hpp:67
void add(T const &val)
Definition: traversal_forest.hpp:35
auto empty()
Definition: traversal_forest.hpp:66
void add(T const &hot, T const &distant)
Definition: traversal_forest.hpp:36
auto next()
Definition: traversal_forest.hpp:46
auto peek()
Definition: traversal_forest.hpp:58
std::conditional_t< order==forest_order::depth_first, std::stack< T >, std::queue< T > > backing_container_t
Definition: traversal_forest.hpp:34
Definition: traversal_forest.hpp:76
node_t node_type
Definition: traversal_forest.hpp:77
typename node_type::id_type node_id_type
Definition: traversal_forest.hpp:78
void for_each(lambda_t &&lambda) const
Definition: traversal_forest.hpp:88
std::pair< tree_id_type, node_id_type > node_uid_type
Definition: traversal_forest.hpp:80
tree_id_t tree_id_type
Definition: traversal_forest.hpp:79
std::size_t index_type
Definition: traversal_forest.hpp:81
traversal_forest(std::vector< node_uid_type > &&root_node_uids)
Definition: traversal_forest.hpp:85
virtual node_type get_node(tree_id_type tree_id, node_id_type node_id) const =0