infer_macros.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
27 
#include <cstddef>
#include <optional>
#include <variant>
30 
/* Parameter list, including its enclosing parentheses, for an `infer`
 * instantiation on a forest that has neither vector leaves nor non-local
 * categorical data.
 *
 * The 8th and 9th parameters are the ones that vary across the
 * CUML_FIL_*_ARGS family. Here both are std::nullptr_t, marking the
 * vector-leaf and non-local categorical data as absent. Only the final
 * parameter carries a name (`stream`). */
#define CUML_FIL_SCALAR_LOCAL_ARGS(dev, variant_index)                   \
  (CUML_FIL_FOREST(variant_index) const&,                                \
   postprocessor<CUML_FIL_SPEC(variant_index)::threshold_type> const&,   \
   CUML_FIL_SPEC(variant_index)::threshold_type*,                        \
   CUML_FIL_SPEC(variant_index)::threshold_type*,                        \
   index_type,                                                           \
   index_type,                                                           \
   index_type,                                                           \
   std::nullptr_t, /* no vector-leaf data */                             \
   std::nullptr_t, /* no non-local categorical data */                   \
   infer_kind,                                                           \
   std::optional<index_type>,                                            \
   raft_proto::device_id<dev>,                                           \
   raft_proto::cuda_stream stream)
47 
/* Parameter list, including its enclosing parentheses, for an `infer`
 * instantiation on a forest that has vector leaves but no non-local
 * categorical data.
 *
 * The 8th parameter is a CUML_FIL_SPEC(variant_index)::threshold_type*
 * (the vector-leaf slot). The 9th is std::nullptr_t, marking non-local
 * categorical data as absent. Only the final parameter carries a name
 * (`stream`). */
#define CUML_FIL_VECTOR_LOCAL_ARGS(dev, variant_index)                   \
  (CUML_FIL_FOREST(variant_index) const&,                                \
   postprocessor<CUML_FIL_SPEC(variant_index)::threshold_type> const&,   \
   CUML_FIL_SPEC(variant_index)::threshold_type*,                        \
   CUML_FIL_SPEC(variant_index)::threshold_type*,                        \
   index_type,                                                           \
   index_type,                                                           \
   index_type,                                                           \
   CUML_FIL_SPEC(variant_index)::threshold_type*, /* vector leaves */    \
   std::nullptr_t, /* no non-local categorical data */                   \
   infer_kind,                                                           \
   std::optional<index_type>,                                            \
   raft_proto::device_id<dev>,                                           \
   raft_proto::cuda_stream stream)
64 
/* Parameter list, including its enclosing parentheses, for an `infer`
 * instantiation on a forest that has no vector leaves but does have
 * non-local categorical data.
 *
 * The 8th parameter is std::nullptr_t, marking vector-leaf data as absent.
 * The 9th is a CUML_FIL_SPEC(variant_index)::index_type* (the non-local
 * categorical slot). Only the final parameter carries a name (`stream`). */
#define CUML_FIL_SCALAR_NONLOCAL_ARGS(dev, variant_index)                \
  (CUML_FIL_FOREST(variant_index) const&,                                \
   postprocessor<CUML_FIL_SPEC(variant_index)::threshold_type> const&,   \
   CUML_FIL_SPEC(variant_index)::threshold_type*,                        \
   CUML_FIL_SPEC(variant_index)::threshold_type*,                        \
   index_type,                                                           \
   index_type,                                                           \
   index_type,                                                           \
   std::nullptr_t, /* no vector-leaf data */                             \
   CUML_FIL_SPEC(variant_index)::index_type*, /* non-local categorical */ \
   infer_kind,                                                           \
   std::optional<index_type>,                                            \
   raft_proto::device_id<dev>,                                           \
   raft_proto::cuda_stream stream)
81 
/* Parameter list, including its enclosing parentheses, for an `infer`
 * instantiation on a forest that has both vector leaves and non-local
 * categorical data.
 *
 * The 8th parameter is a CUML_FIL_SPEC(variant_index)::threshold_type*
 * (the vector-leaf slot). The 9th is a CUML_FIL_SPEC(variant_index)::index_type*
 * (the non-local categorical slot). Only the final parameter carries a name
 * (`stream`). */
#define CUML_FIL_VECTOR_NONLOCAL_ARGS(dev, variant_index)                \
  (CUML_FIL_FOREST(variant_index) const&,                                \
   postprocessor<CUML_FIL_SPEC(variant_index)::threshold_type> const&,   \
   CUML_FIL_SPEC(variant_index)::threshold_type*,                        \
   CUML_FIL_SPEC(variant_index)::threshold_type*,                        \
   index_type,                                                           \
   index_type,                                                           \
   index_type,                                                           \
   CUML_FIL_SPEC(variant_index)::threshold_type*, /* vector leaves */    \
   CUML_FIL_SPEC(variant_index)::index_type*, /* non-local categorical */ \
   infer_kind,                                                           \
   std::optional<index_type>,                                            \
   raft_proto::device_id<dev>,                                           \
   raft_proto::cuda_stream stream)
98 
/* Head of an `infer` instantiation for the forest type selected by
 * variant_index. It covers everything up to, but not including, the
 * parenthesised parameter list, which the caller must append.
 *
 * template_type is pasted verbatim before the return type `void`.
 * Presumably callers pass `template` or `extern template`; confirm at the
 * call sites. The template arguments of infer are, in order: dev,
 * categorical, CUML_FIL_FOREST(variant_index). */
#define CUML_FIL_INFER_TEMPLATE(template_type, dev, variant_index, categorical) \
  template_type void infer<dev, categorical, CUML_FIL_FOREST(variant_index)>
103 
/* Emits one complete `infer` instantiation statement on device `dev` for the
 * forest type selected by variant_index. The forest has scalar leaves and no
 * categorical nodes, so the categorical flag is false and the
 * CUML_FIL_SCALAR_LOCAL_ARGS parameter list is used. The expansion ends with
 * its own semicolon. */
#define CUML_FIL_INFER_DEV_SCALAR_LEAF_NO_CAT(template_type, dev, variant_index) \
  CUML_FIL_INFER_TEMPLATE(template_type, dev, variant_index, false)             \
  CUML_FIL_SCALAR_LOCAL_ARGS(dev, variant_index);
110 
/* Emits one complete `infer` instantiation statement on device `dev` for the
 * forest type selected by variant_index. The forest has scalar leaves and
 * only local categorical nodes, so the categorical flag is true and the
 * CUML_FIL_SCALAR_LOCAL_ARGS parameter list is used. The expansion ends with
 * its own semicolon. */
#define CUML_FIL_INFER_DEV_SCALAR_LEAF_LOCAL_CAT(template_type, dev, variant_index) \
  CUML_FIL_INFER_TEMPLATE(template_type, dev, variant_index, true)                 \
  CUML_FIL_SCALAR_LOCAL_ARGS(dev, variant_index);
117 
/* Emits one complete `infer` instantiation statement on device `dev` for the
 * forest type selected by variant_index. The forest has scalar leaves and
 * non-local categorical nodes, so the categorical flag is true and the
 * CUML_FIL_SCALAR_NONLOCAL_ARGS parameter list is used. The expansion ends
 * with its own semicolon. */
#define CUML_FIL_INFER_DEV_SCALAR_LEAF_NONLOCAL_CAT(template_type, dev, variant_index) \
  CUML_FIL_INFER_TEMPLATE(template_type, dev, variant_index, true)                    \
  CUML_FIL_SCALAR_NONLOCAL_ARGS(dev, variant_index);
124 
/* Emits one complete `infer` instantiation statement on device `dev` for the
 * forest type selected by variant_index. The forest has vector leaves and no
 * categorical nodes, so the categorical flag is false and the
 * CUML_FIL_VECTOR_LOCAL_ARGS parameter list is used. The expansion ends with
 * its own semicolon. */
#define CUML_FIL_INFER_DEV_VECTOR_LEAF_NO_CAT(template_type, dev, variant_index) \
  CUML_FIL_INFER_TEMPLATE(template_type, dev, variant_index, false)             \
  CUML_FIL_VECTOR_LOCAL_ARGS(dev, variant_index);
131 
/* Emits one complete `infer` instantiation statement on device `dev` for the
 * forest type selected by variant_index. The forest has vector leaves and
 * only local categorical nodes, so the categorical flag is true and the
 * CUML_FIL_VECTOR_LOCAL_ARGS parameter list is used. The expansion ends with
 * its own semicolon. */
#define CUML_FIL_INFER_DEV_VECTOR_LEAF_LOCAL_CAT(template_type, dev, variant_index) \
  CUML_FIL_INFER_TEMPLATE(template_type, dev, variant_index, true)                 \
  CUML_FIL_VECTOR_LOCAL_ARGS(dev, variant_index);
138 
/* Emits one complete `infer` instantiation statement on device `dev` for the
 * forest type selected by variant_index. The forest has vector leaves and
 * non-local categorical nodes, so the categorical flag is true and the
 * CUML_FIL_VECTOR_NONLOCAL_ARGS parameter list is used. The expansion ends
 * with its own semicolon. */
#define CUML_FIL_INFER_DEV_VECTOR_LEAF_NONLOCAL_CAT(template_type, dev, variant_index) \
  CUML_FIL_INFER_TEMPLATE(template_type, dev, variant_index, true)                    \
  CUML_FIL_VECTOR_NONLOCAL_ARGS(dev, variant_index);
145 
/* Emits every valid `infer` instantiation on device `dev` for the forest type
 * selected by variant_index. This is one statement per combination of
 *   leaf kind:         scalar, vector
 *   categorical nodes: none, local only, non-local
 * emitted in that order (scalar first, then vector). template_type is
 * forwarded unchanged to each of the six per-combination macros. */
#define CUML_FIL_INFER_ALL(template_type, dev, variant_index)                     \
  CUML_FIL_INFER_DEV_SCALAR_LEAF_NO_CAT(template_type, dev, variant_index)        \
  CUML_FIL_INFER_DEV_SCALAR_LEAF_LOCAL_CAT(template_type, dev, variant_index)     \
  CUML_FIL_INFER_DEV_SCALAR_LEAF_NONLOCAL_CAT(template_type, dev, variant_index)  \
  CUML_FIL_INFER_DEV_VECTOR_LEAF_NO_CAT(template_type, dev, variant_index)        \
  CUML_FIL_INFER_DEV_VECTOR_LEAF_LOCAL_CAT(template_type, dev, variant_index)     \
  CUML_FIL_INFER_DEV_VECTOR_LEAF_NONLOCAL_CAT(template_type, dev, variant_index)