specialization_types.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
8 
9 #include <cstddef>
10 #include <cstdint>
11 #include <type_traits>
12 #include <variant>
13 
14 namespace ML {
15 namespace fil {
16 namespace detail {
17 
18 /*
19  * A template used solely to help manage the types which will be compiled in
20  * standard cuML FIL
21  *
22  * The relatively simple and human-readable template parameters of this
23  * template are translated into the specific types and values required
24  * to instantiate more complex templates and compile-time checks.
25  *
26  * @tparam layout_v The layout of trees within a model
27  * @tparam double_precision Whether this model should use double-precision
28  * for floating-point evaluation and 64-bit integers for indexes
29  * @tparam large_trees Whether this forest expects more than 2**(16 -3) - 1 =
30  * 8191 features or contains nodes whose child is offset more than 2**16 - 1 = 65535 nodes away.
31  */
32 template <tree_layout layout_v, bool double_precision, bool large_trees>
34  /* The node threshold type to be used based on the template parameters
35  */
36  using threshold_type = std::conditional_t<double_precision, double, float>;
37  /* The type required for specifying indexes to vector leaf outputs or
38  * non-local categorical data.
39  */
40  using index_type = std::conditional_t<double_precision, std::uint64_t, std::uint32_t>;
41  /* The type used to provide metadata storage for nodes */
42  using metadata_type = std::conditional_t<large_trees, std::uint32_t, std::uint16_t>;
43  /* The type used to provide metadata storage for nodes */
44  using offset_type = std::conditional_t<large_trees, std::uint32_t, std::uint16_t>;
45  /* The tree layout (alias for layout_v)*/
46  auto static constexpr const layout = layout_v;
47  /* Whether or not this tree requires double precision (alias for
48  * double_precision)
49  */
50  auto static constexpr const is_double_precision = double_precision;
51  /* Whether or not this forest contains large trees (alias for
52  * large_trees)
53  */
54  auto static constexpr const has_large_trees = large_trees;
55 };
56 
57 /* A variant holding information on all specialization types compiled
58  * in standard cuML FIL
59  */
61  std::variant<specialization_types<tree_layout::depth_first, false, false>,
73 
74 } // namespace detail
75 } // namespace fil
76 } // namespace ML
std::variant< specialization_types< tree_layout::depth_first, false, false >, specialization_types< tree_layout::depth_first, false, true >, specialization_types< tree_layout::depth_first, true, false >, specialization_types< tree_layout::depth_first, true, true >, specialization_types< tree_layout::breadth_first, false, false >, specialization_types< tree_layout::breadth_first, false, true >, specialization_types< tree_layout::breadth_first, true, false >, specialization_types< tree_layout::breadth_first, true, true >, specialization_types< tree_layout::layered_children_together, false, false >, specialization_types< tree_layout::layered_children_together, false, true >, specialization_types< tree_layout::layered_children_together, true, false >, specialization_types< tree_layout::layered_children_together, true, true > > specialization_variant
Definition: specialization_types.hpp:72
Definition: dbscan.hpp:18
Definition: specialization_types.hpp:33
static constexpr auto const has_large_trees
Definition: specialization_types.hpp:54
std::conditional_t< double_precision, double, float > threshold_type
Definition: specialization_types.hpp:36
std::conditional_t< large_trees, std::uint32_t, std::uint16_t > metadata_type
Definition: specialization_types.hpp:42
std::conditional_t< double_precision, std::uint64_t, std::uint32_t > index_type
Definition: specialization_types.hpp:40
std::conditional_t< large_trees, std::uint32_t, std::uint16_t > offset_type
Definition: specialization_types.hpp:44
static constexpr auto const layout
Definition: specialization_types.hpp:46
static constexpr auto const is_double_precision
Definition: specialization_types.hpp:50