forest_macros.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 #include <cuml/fil/constants.hpp>
10 
11 #include <variant>
12 
/* Macro which, given a variant index, will extract the type of the
 * corresponding variant from the specialization_variant type. This allows us
 * to specify all forest variants we wish to support in one location and then
 * reference them by index elsewhere.
 *
 * variant_index must be a constant expression usable as a template argument
 * (e.g. an integer literal or a non-type template parameter). It is wrapped in
 * parentheses in the expansion so that an argument containing a top-level `>`
 * (such as `N > 1 ? 1 : 0`) is not misparsed as the end of the template
 * argument list.
 *
 * `fil::detail::specialization_variant` is written without a leading `::`, so
 * it is resolved relative to the scope in which the macro is expanded. */
#define CUML_FIL_SPEC(variant_index) \
  std::variant_alternative_t<(variant_index), fil::detail::specialization_variant>
19 
/* Macro which expands to the full forest type (a specialization of the
 * `forest` template) corresponding to the given variant index.
 *
 * The template arguments are taken, in order, from the selected
 * specialization's `layout` member (used without `typename`, so it is expected
 * to be a value rather than a type), then its `threshold_type`, `index_type`,
 * `metadata_type` and `offset_type` member types. `forest` is named without
 * qualification, so it must be visible where the macro is expanded.
 *
 * variant_index is parenthesized before being forwarded to CUML_FIL_SPEC so an
 * argument containing a top-level `>` (such as `N > 1 ? 1 : 0`) cannot end a
 * template argument list early. */
#define CUML_FIL_FOREST(variant_index)                            \
  forest<CUML_FIL_SPEC((variant_index))::layout,                  \
         typename CUML_FIL_SPEC((variant_index))::threshold_type, \
         typename CUML_FIL_SPEC((variant_index))::index_type,     \
         typename CUML_FIL_SPEC((variant_index))::metadata_type,  \
         typename CUML_FIL_SPEC((variant_index))::offset_type>