infer_macros.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
#include <cuml/fil/constants.hpp>
#include <cuml/fil/infer_kind.hpp>

#include <cstddef>
#include <optional>
#include <variant>
19 
/* Internal helper: expands to the parenthesized parameter list of an inference
 * call for the forest type selected by variant_index on device dev.
 *
 * The four public *_ARGS macros below share every parameter except the eighth
 * and ninth, so the list is written out once here and those two slots are
 * supplied by the caller:
 *   - vector_leaf_param_type: the eighth parameter's type.
 *     CUML_FIL_SPEC(variant_index)::threshold_type* for forests with vector
 *     leaves, std::nullptr_t otherwise.
 *   - nonlocal_cat_param_type: the ninth parameter's type.
 *     CUML_FIL_SPEC(variant_index)::index_type* for forests with non-local
 *     categorical data, std::nullptr_t otherwise.
 *
 * Because the list lives in one place, a change to the inference signature
 * cannot leave the four variants out of sync. Arguments are split on top-level
 * commas before they are macro-expanded, so passing
 * CUML_FIL_SPEC(variant_index)::threshold_type* as one argument is safe even if
 * CUML_FIL_SPEC itself expands to text containing commas. */
#define CUML_FIL_DETAIL_INFER_ARGS(                                   \
  dev, variant_index, vector_leaf_param_type, nonlocal_cat_param_type) \
  (CUML_FIL_FOREST(variant_index) const&,                              \
   postprocessor<CUML_FIL_SPEC(variant_index)::threshold_type> const&, \
   CUML_FIL_SPEC(variant_index)::threshold_type*,                      \
   CUML_FIL_SPEC(variant_index)::threshold_type*,                      \
   index_type,                                                         \
   index_type,                                                         \
   index_type,                                                         \
   vector_leaf_param_type,                                             \
   nonlocal_cat_param_type,                                            \
   infer_kind,                                                         \
   std::optional<index_type>,                                          \
   raft_proto::device_id<dev>,                                         \
   raft_proto::cuda_stream stream)

/* Macro which expands to the valid arguments to an inference call for a forest
 * model without vector leaves or non-local categorical data. Both optional
 * pointer slots are std::nullptr_t. */
#define CUML_FIL_SCALAR_LOCAL_ARGS(dev, variant_index) \
  CUML_FIL_DETAIL_INFER_ARGS(dev, variant_index, std::nullptr_t, std::nullptr_t)

/* Macro which expands to the valid arguments to an inference call for a forest
 * model with vector leaves but without non-local categorical data. The
 * vector-leaf slot is a threshold_type pointer; the categorical slot is
 * std::nullptr_t. */
#define CUML_FIL_VECTOR_LOCAL_ARGS(dev, variant_index) \
  CUML_FIL_DETAIL_INFER_ARGS(                          \
    dev, variant_index, CUML_FIL_SPEC(variant_index)::threshold_type*, std::nullptr_t)

/* Macro which expands to the valid arguments to an inference call for a forest
 * model without vector leaves but with non-local categorical data. The
 * vector-leaf slot is std::nullptr_t; the categorical slot is a pointer to the
 * specialization's index_type. */
#define CUML_FIL_SCALAR_NONLOCAL_ARGS(dev, variant_index) \
  CUML_FIL_DETAIL_INFER_ARGS(                             \
    dev, variant_index, std::nullptr_t, CUML_FIL_SPEC(variant_index)::index_type*)

/* Macro which expands to the valid arguments to an inference call for a forest
 * model with vector leaves and with non-local categorical data. Both optional
 * pointer slots are real pointers: threshold_type* for the vector-leaf slot and
 * the specialization's index_type* for the categorical slot. */
#define CUML_FIL_VECTOR_NONLOCAL_ARGS(dev, variant_index)   \
  CUML_FIL_DETAIL_INFER_ARGS(dev,                           \
                             variant_index,                 \
                             CUML_FIL_SPEC(variant_index)::threshold_type*, \
                             CUML_FIL_SPEC(variant_index)::index_type*)
87 
/* Expands to the head of an infer() declaration for the forest type selected by
 * variant_index: the template_type keyword(s), the `void` return type and the
 * template argument list <dev, categorical, forest-type>. The parameter list is
 * NOT included. Callers append it, typically via one of the *_ARGS macros.
 * NOTE(review): template_type is presumably `template` or `extern template`;
 * confirm at the expansion sites. */
#define CUML_FIL_INFER_TEMPLATE(inst_kw, device, fidx, is_categorical) \
  inst_kw void                                                          \
  infer<device, is_categorical, CUML_FIL_FOREST(fidx)>
92 
/* Expands to one infer() declaration statement, terminated by ';', for the
 * forest type selected by variant_index on device dev. It covers forests
 * without vector leaves and without categorical nodes: the categorical template
 * flag is false, and the parameter list comes from
 * CUML_FIL_SCALAR_LOCAL_ARGS. */
#define CUML_FIL_INFER_DEV_SCALAR_LEAF_NO_CAT(inst_kw, device, fidx) \
  CUML_FIL_INFER_TEMPLATE(inst_kw, device, fidx, false)              \
  CUML_FIL_SCALAR_LOCAL_ARGS(device, fidx);
99 
/* Expands to one infer() declaration statement, terminated by ';', for the
 * forest type selected by variant_index on device dev. It covers forests
 * without vector leaves whose categorical nodes are all local: the categorical
 * template flag is true, and the parameter list comes from
 * CUML_FIL_SCALAR_LOCAL_ARGS, so no non-local categorical pointer is
 * passed. */
#define CUML_FIL_INFER_DEV_SCALAR_LEAF_LOCAL_CAT(inst_kw, device, fidx) \
  CUML_FIL_INFER_TEMPLATE(inst_kw, device, fidx, true)                  \
  CUML_FIL_SCALAR_LOCAL_ARGS(device, fidx);
106 
/* Expands to one infer() declaration statement, terminated by ';', for the
 * forest type selected by variant_index on device dev. It covers forests
 * without vector leaves that have non-local categorical nodes: the categorical
 * template flag is true, and the parameter list comes from
 * CUML_FIL_SCALAR_NONLOCAL_ARGS. */
#define CUML_FIL_INFER_DEV_SCALAR_LEAF_NONLOCAL_CAT(inst_kw, device, fidx) \
  CUML_FIL_INFER_TEMPLATE(inst_kw, device, fidx, true)                     \
  CUML_FIL_SCALAR_NONLOCAL_ARGS(device, fidx);
113 
/* Expands to one infer() declaration statement, terminated by ';', for the
 * forest type selected by variant_index on device dev. It covers forests with
 * vector leaves and without categorical nodes: the categorical template flag is
 * false, and the parameter list comes from CUML_FIL_VECTOR_LOCAL_ARGS. */
#define CUML_FIL_INFER_DEV_VECTOR_LEAF_NO_CAT(inst_kw, device, fidx) \
  CUML_FIL_INFER_TEMPLATE(inst_kw, device, fidx, false)              \
  CUML_FIL_VECTOR_LOCAL_ARGS(device, fidx);
120 
/* Expands to one infer() declaration statement, terminated by ';', for the
 * forest type selected by variant_index on device dev. It covers forests with
 * vector leaves whose categorical nodes are all local: the categorical template
 * flag is true, and the parameter list comes from
 * CUML_FIL_VECTOR_LOCAL_ARGS. */
#define CUML_FIL_INFER_DEV_VECTOR_LEAF_LOCAL_CAT(inst_kw, device, fidx) \
  CUML_FIL_INFER_TEMPLATE(inst_kw, device, fidx, true)                  \
  CUML_FIL_VECTOR_LOCAL_ARGS(device, fidx);
127 
/* Expands to one infer() declaration statement, terminated by ';', for the
 * forest type selected by variant_index on device dev. It covers forests with
 * vector leaves and non-local categorical nodes: the categorical template flag
 * is true, and the parameter list comes from CUML_FIL_VECTOR_NONLOCAL_ARGS. */
#define CUML_FIL_INFER_DEV_VECTOR_LEAF_NONLOCAL_CAT(inst_kw, device, fidx) \
  CUML_FIL_INFER_TEMPLATE(inst_kw, device, fidx, true)                     \
  CUML_FIL_VECTOR_NONLOCAL_ARGS(device, fidx);
134 
/* Expands to all six valid infer() declarations for the forest type selected by
 * variant_index on device dev. These are every pairing of {scalar, vector}
 * leaves with {no, local-only, non-local} categorical nodes. Each generated
 * declaration already ends in ';'. */
#define CUML_FIL_INFER_ALL(inst_kw, device, fidx)                        \
  CUML_FIL_INFER_DEV_SCALAR_LEAF_NO_CAT(inst_kw, device, fidx)           \
  CUML_FIL_INFER_DEV_SCALAR_LEAF_LOCAL_CAT(inst_kw, device, fidx)        \
  CUML_FIL_INFER_DEV_SCALAR_LEAF_NONLOCAL_CAT(inst_kw, device, fidx)     \
  CUML_FIL_INFER_DEV_VECTOR_LEAF_NO_CAT(inst_kw, device, fidx)           \
  CUML_FIL_INFER_DEV_VECTOR_LEAF_LOCAL_CAT(inst_kw, device, fidx)        \
  CUML_FIL_INFER_DEV_VECTOR_LEAF_NONLOCAL_CAT(inst_kw, device, fidx)