infer.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
12 #include <cuml/fil/exceptions.hpp>
13 #include <cuml/fil/infer_kind.hpp>
14 
15 #include <cstddef>
16 #include <iostream>
17 #include <optional>
18 #include <type_traits>
19 
20 #ifdef CUML_ENABLE_GPU
22 #endif
23 
24 namespace ML {
25 namespace fil {
26 namespace detail {
27 
28 /*
29  * Perform inference based on the given forest and input parameters
30  *
31  * @tparam D The device type (CPU/GPU) used to perform inference
32  * @tparam forest_t The type of the forest
33  * @param forest The forest to be evaluated
34  * @param postproc The postprocessor object used to execute
35  * postprocessing
36  * @param output Pointer to where the output should be written
37  * @param input Pointer to where the input data can be read from
38  * @param row_count The number of rows in the input data
39  * @param col_count The number of columns in the input data
40  * @param output_count The number of outputs per row
41  * @param has_categorical_nodes Whether or not any node within the forest has
42  * a categorical split
43  * @param vector_output Pointer to the beginning of storage for vector
44  * outputs of leaves (nullptr for no vector output)
45  * @param categorical_data Pointer to external categorical data storage if
46  * required
47  * @param infer_type Type of inference to perform. Defaults to summing the outputs of all trees
48  * and produce an output per row. If set to "per_tree", we will instead output all outputs of
49  * individual trees. If set to "leaf_id", we will output the integer ID of the leaf node
50  * for each tree.
51  * @param specified_chunk_size If non-nullopt, the size of "mini-batches"
52  * used for distributing work across threads
53  * @param device The device on which to execute evaluation
54  * @param stream Optionally, the CUDA stream to use
55  */
56 template <raft_proto::device_type D, typename forest_t>
57 void infer(forest_t const& forest,
59  typename forest_t::io_type* output,
60  typename forest_t::io_type* input,
61  index_type row_count,
62  index_type col_count,
63  index_type output_count,
64  bool has_categorical_nodes,
65  typename forest_t::io_type* vector_output = nullptr,
66  typename forest_t::node_type::index_type* categorical_data = nullptr,
68  std::optional<index_type> specified_chunk_size = std::nullopt,
71 {
72  if (vector_output == nullptr) {
73  if (categorical_data == nullptr) {
74  if (!has_categorical_nodes) {
75  inference::infer<D, false, forest_t, std::nullptr_t, std::nullptr_t>(forest,
76  postproc,
77  output,
78  input,
79  row_count,
80  col_count,
81  output_count,
82  nullptr,
83  nullptr,
84  infer_type,
85  specified_chunk_size,
86  device,
87  stream);
88  } else {
89  inference::infer<D, true, forest_t, std::nullptr_t, std::nullptr_t>(forest,
90  postproc,
91  output,
92  input,
93  row_count,
94  col_count,
95  output_count,
96  nullptr,
97  nullptr,
98  infer_type,
99  specified_chunk_size,
100  device,
101  stream);
102  }
103  } else {
104  inference::infer<D, true, forest_t>(forest,
105  postproc,
106  output,
107  input,
108  row_count,
109  col_count,
110  output_count,
111  nullptr,
112  categorical_data,
113  infer_type,
114  specified_chunk_size,
115  device,
116  stream);
117  }
118  } else {
119  if (categorical_data == nullptr) {
120  if (!has_categorical_nodes) {
121  inference::infer<D, false, forest_t>(forest,
122  postproc,
123  output,
124  input,
125  row_count,
126  col_count,
127  output_count,
128  vector_output,
129  nullptr,
130  infer_type,
131  specified_chunk_size,
132  device,
133  stream);
134  } else {
135  inference::infer<D, true, forest_t>(forest,
136  postproc,
137  output,
138  input,
139  row_count,
140  col_count,
141  output_count,
142  vector_output,
143  nullptr,
144  infer_type,
145  specified_chunk_size,
146  device,
147  stream);
148  }
149  } else {
150  inference::infer<D, true, forest_t>(forest,
151  postproc,
152  output,
153  input,
154  row_count,
155  col_count,
156  output_count,
157  vector_output,
158  categorical_data,
159  infer_type,
160  specified_chunk_size,
161  device,
162  stream);
163  }
164  }
165 }
166 
167 } // namespace detail
168 } // namespace fil
169 } // namespace ML
void infer(forest_t const &forest, postprocessor< typename forest_t::io_type > const &postproc, typename forest_t::io_type *output, typename forest_t::io_type *input, index_type row_count, index_type col_count, index_type output_count, bool has_categorical_nodes, typename forest_t::io_type *vector_output=nullptr, typename forest_t::node_type::index_type *categorical_data=nullptr, infer_kind infer_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt, raft_proto::device_id< D > device=raft_proto::device_id< D >{}, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: infer.hpp:57
infer_kind
Definition: infer_kind.hpp:8
uint32_t index_type
Definition: index_type.hpp:9
Definition: dbscan.hpp:18
int cuda_stream
Definition: cuda_stream.hpp:14
Definition: forest.hpp:24
Definition: postprocessor.hpp:135
Definition: base.hpp:11