infer.hpp
/*
 * Copyright (c) 2023-2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#include <cuml/fil/detail/index_type.hpp>
#include <cuml/fil/detail/infer/cpu.hpp>
#include <cuml/fil/detail/postprocessor.hpp>
#include <cuml/fil/detail/raft_proto/cuda_stream.hpp>
#include <cuml/fil/detail/raft_proto/device_id.hpp>
#include <cuml/fil/detail/raft_proto/device_type.hpp>
#include <cuml/fil/exceptions.hpp>
#include <cuml/fil/infer_kind.hpp>

#include <cstddef>
#include <iostream>
#include <optional>
#include <type_traits>

#ifdef CUML_ENABLE_GPU
#include <cuml/fil/detail/infer/gpu.hpp>
#endif

namespace ML {
namespace fil {
namespace detail {

/*
 * Perform inference based on the given forest and input parameters
 *
 * @tparam D The device type (CPU/GPU) used to perform inference
 * @tparam forest_t The type of the forest
 * @param forest The forest to be evaluated
 * @param postproc The postprocessor object used to execute postprocessing
 * @param output Pointer to where the output should be written
 * @param input Pointer to where the input data can be read from
 * @param row_count The number of rows in the input data
 * @param col_count The number of columns in the input data
 * @param output_count The number of outputs per row
 * @param has_categorical_nodes Whether or not any node within the forest has
 * a categorical split
 * @param vector_output Pointer to the beginning of storage for vector
 * outputs of leaves (nullptr for no vector output)
 * @param categorical_data Pointer to external categorical data storage, if
 * required
 * @param infer_type The type of inference to perform. The default sums the
 * outputs of all trees and produces one output per row. If set to
 * "per_tree", the outputs of each individual tree are returned instead. If
 * set to "leaf_id", the integer ID of the leaf node reached in each tree is
 * returned.
 * @param specified_chunk_size If non-nullopt, the size of the "mini-batches"
 * used for distributing work across threads
 * @param device The device on which to execute evaluation
 * @param stream Optionally, the CUDA stream to use
 */
template <raft_proto::device_type D, typename forest_t>
void infer(forest_t const& forest,
           postprocessor<typename forest_t::io_type> const& postproc,
           typename forest_t::io_type* output,
           typename forest_t::io_type* input,
           index_type row_count,
           index_type col_count,
           index_type output_count,
           bool has_categorical_nodes,
           typename forest_t::io_type* vector_output                  = nullptr,
           typename forest_t::node_type::index_type* categorical_data = nullptr,
           infer_kind infer_type                          = infer_kind::default_kind,
           std::optional<index_type> specified_chunk_size = std::nullopt,
           raft_proto::device_id<D> device                = raft_proto::device_id<D>{},
           raft_proto::cuda_stream stream                 = raft_proto::cuda_stream{})
{
  // The caller's runtime configuration is translated into compile-time
  // template arguments here: one specialization per combination of
  // categorical-node support, vector leaf output, and external categorical
  // storage.
  if (vector_output == nullptr) {
    if (categorical_data == nullptr) {
      if (!has_categorical_nodes) {
        // Scalar leaves, purely numerical splits
        inference::infer<D, false, forest_t, std::nullptr_t, std::nullptr_t>(
          forest, postproc, output, input, row_count, col_count, output_count,
          nullptr, nullptr, infer_type, specified_chunk_size, device, stream);
      } else {
        // Scalar leaves, categorical splits stored inline in the nodes
        inference::infer<D, true, forest_t, std::nullptr_t, std::nullptr_t>(
          forest, postproc, output, input, row_count, col_count, output_count,
          nullptr, nullptr, infer_type, specified_chunk_size, device, stream);
      }
    } else {
      // Scalar leaves, categorical splits backed by external storage
      inference::infer<D, true, forest_t>(
        forest, postproc, output, input, row_count, col_count, output_count,
        nullptr, categorical_data, infer_type, specified_chunk_size, device, stream);
    }
  } else {
    if (categorical_data == nullptr) {
      if (!has_categorical_nodes) {
        // Vector leaves, purely numerical splits
        inference::infer<D, false, forest_t>(
          forest, postproc, output, input, row_count, col_count, output_count,
          vector_output, nullptr, infer_type, specified_chunk_size, device, stream);
      } else {
        // Vector leaves, categorical splits stored inline in the nodes
        inference::infer<D, true, forest_t>(
          forest, postproc, output, input, row_count, col_count, output_count,
          vector_output, nullptr, infer_type, specified_chunk_size, device, stream);
      }
    } else {
      // Vector leaves, categorical splits backed by external storage
      inference::infer<D, true, forest_t>(
        forest, postproc, output, input, row_count, col_count, output_count,
        vector_output, categorical_data, infer_type, specified_chunk_size, device, stream);
    }
  }
}

}  // namespace detail
}  // namespace fil
}  // namespace ML
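
/*
 * Usage sketch (illustrative; not part of the original header). The concrete
 * forest type and the `forest` and `postproc` objects are assumed to come
 * from FIL's model-import machinery and are hypothetical stand-ins here;
 * only the call shape of detail::infer is taken from the listing above.
 * With the trailing parameters left at their defaults, this sums the outputs
 * of all trees and writes one set of outputs per row on the CPU.
 */
template <typename forest_t>
void example_cpu_infer(forest_t const& forest,
                       ML::fil::detail::postprocessor<typename forest_t::io_type> const& postproc,
                       typename forest_t::io_type* input,   // row_count * col_count values, row-major
                       typename forest_t::io_type* output,  // space for row_count * output_count values
                       ML::fil::index_type row_count,
                       ML::fil::index_type col_count,
                       ML::fil::index_type output_count,
                       bool has_categorical_nodes)
{
  ML::fil::detail::infer<raft_proto::device_type::cpu>(forest,
                                                       postproc,
                                                       output,
                                                       input,
                                                       row_count,
                                                       col_count,
                                                       output_count,
                                                       has_categorical_nodes);
}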
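#ifdef CUML_ENABLE_GPU
/*
 * GPU sketch under the same assumptions, with the defaulted arguments
 * spelled out: no vector leaf output, no external categorical storage, the
 * default inference kind, and an unspecified chunk size (letting the backend
 * choose). For GPU execution, `input` and `output` must point to
 * device-accessible memory, and the caller supplies the stream on which the
 * work is enqueued.
 */
template <typename forest_t>
void example_gpu_infer(forest_t const& forest,
                       ML::fil::detail::postprocessor<typename forest_t::io_type> const& postproc,
                       typename forest_t::io_type* input,
                       typename forest_t::io_type* output,
                       ML::fil::index_type row_count,
                       ML::fil::index_type col_count,
                       ML::fil::index_type output_count,
                       bool has_categorical_nodes,
                       raft_proto::cuda_stream stream)
{
  ML::fil::detail::infer<raft_proto::device_type::gpu>(
    forest,
    postproc,
    output,
    input,
    row_count,
    col_count,
    output_count,
    has_categorical_nodes,
    nullptr,                            // vector_output: scalar leaves only
    nullptr,                            // categorical_data: none stored externally
    ML::fil::infer_kind::default_kind,
    std::nullopt,                       // specified_chunk_size: let the backend pick
    raft_proto::device_id<raft_proto::device_type::gpu>{},  // default device
    stream);
}
#endif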