forest.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
9 
10 #include <stddef.h>
11 
12 #include <type_traits>
13 
14 namespace ML {
15 namespace fil {
16 
17 /* A collection of trees which together form a forest model
18  */
19 template <tree_layout layout_v,
20  typename threshold_t,
21  typename index_t,
22  typename metadata_storage_t,
23  typename offset_t>
24 struct forest {
26  using io_type = threshold_t;
27  template <typename vector_output_t>
28  using raw_output_type = std::conditional_t<!std::is_same_v<vector_output_t, std::nullptr_t>,
29  std::remove_pointer_t<vector_output_t>,
30  typename node_type::threshold_type>;
31 
32  HOST DEVICE forest(node_type* forest_nodes,
33  index_type* forest_root_indexes,
34  index_type* node_id_mapping,
35  io_type* bias,
36  index_type num_trees,
38  : nodes_{forest_nodes},
39  root_node_indexes_{forest_root_indexes},
40  node_id_mapping_{node_id_mapping},
41  bias_{bias},
42  num_trees_{num_trees},
43  num_outputs_{num_outputs}
44  {
45  }
46 
47  /* Return pointer to the root node of the indicated tree */
48  HOST DEVICE auto* get_tree_root(index_type tree_index) const
49  {
50  return nodes_ + root_node_indexes_[tree_index];
51  }
52 
53  /* Return pointer to the mapping from internal node IDs to final node ID outputs.
54  * Only used when infer_type == infer_kind::leaf_id */
55  HOST DEVICE const auto* get_node_id_mapping() const { return node_id_mapping_; }
56 
57  /* Return pointer to the bias term */
58  HOST DEVICE const auto* bias() const { return bias_; }
59 
60  /* Return the number of trees in this forest */
61  HOST DEVICE auto tree_count() const { return num_trees_; }
62 
63  /* Return the number of outputs per row for default evaluation of this
64  * forest */
65  HOST DEVICE auto num_outputs() const { return num_outputs_; }
66 
67  private:
68  node_type* nodes_;
69  index_type* root_node_indexes_;
70  index_type* node_id_mapping_;
71  io_type* bias_;
72  index_type num_trees_;
73  index_type num_outputs_;
74 };
75 
76 } // namespace fil
77 } // namespace ML
#define DEVICE
Definition: gpu_support.hpp:24
#define HOST
Definition: gpu_support.hpp:23
tree_layout
Definition: tree_layout.hpp:8
uint32_t index_type
Definition: index_type.hpp:9
Definition: dbscan.hpp:18
Definition: forest.hpp:24
HOST DEVICE auto num_outputs() const
Definition: forest.hpp:65
HOST DEVICE auto * get_tree_root(index_type tree_index) const
Definition: forest.hpp:48
threshold_t io_type
Definition: forest.hpp:26
std::conditional_t<!std::is_same_v< vector_output_t, std::nullptr_t >, std::remove_pointer_t< vector_output_t >, typename node_type::threshold_type > raw_output_type
Definition: forest.hpp:30
HOST DEVICE const auto * get_node_id_mapping() const
Definition: forest.hpp:55
HOST DEVICE auto tree_count() const
Definition: forest.hpp:61
HOST DEVICE const auto * bias() const
Definition: forest.hpp:58
node< layout_v, threshold_t, index_t, metadata_storage_t, offset_t > node_type
Definition: forest.hpp:25
HOST DEVICE forest(node_type *forest_nodes, index_type *forest_root_indexes, index_type *node_id_mapping, io_type *bias, index_type num_trees, index_type num_outputs)
Definition: forest.hpp:32
Definition: node.hpp:81
threshold_t threshold_type
Definition: node.hpp:85