infer.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
25 
26 #include <cstddef>
27 #include <iostream>
28 #include <optional>
29 #include <type_traits>
30 
31 #ifdef CUML_ENABLE_GPU
33 #endif
34 
35 namespace ML {
36 namespace experimental {
37 namespace fil {
38 namespace detail {
39 
40 /*
41  * Perform inference based on the given forest and input parameters
42  *
43  * @tparam D The device type (CPU/GPU) used to perform inference
44  * @tparam forest_t The type of the forest
45  * @param forest The forest to be evaluated
46  * @param postproc The postprocessor object used to execute
47  * postprocessing
48  * @param output Pointer to where the output should be written
49  * @param input Pointer to where the input data can be read from
50  * @param row_count The number of rows in the input data
51  * @param col_count The number of columns in the input data
52  * @param output_count The number of outputs per row
53  * @param has_categorical_nodes Whether or not any node within the forest has
54  * a categorical split
55  * @param vector_output Pointer to the beginning of storage for vector
56  * outputs of leaves (nullptr for no vector output)
57  * @param categorical_data Pointer to external categorical data storage if
58  * required
59  * @param infer_type Type of inference to perform. Defaults to summing the outputs of all trees
60  * and produce an output per row. If set to "per_tree", we will instead output all outputs of
61  * individual trees. If set to "leaf_id", we will output the integer ID of the leaf node
62  * for each tree.
63  * @param specified_chunk_size If non-nullopt, the size of "mini-batches"
64  * used for distributing work across threads
65  * @param device The device on which to execute evaluation
66  * @param stream Optionally, the CUDA stream to use
67  */
68 template <raft_proto::device_type D, typename forest_t>
69 void infer(forest_t const& forest,
71  typename forest_t::io_type* output,
72  typename forest_t::io_type* input,
73  index_type row_count,
74  index_type col_count,
75  index_type output_count,
76  bool has_categorical_nodes,
77  typename forest_t::io_type* vector_output = nullptr,
78  typename forest_t::node_type::index_type* categorical_data = nullptr,
80  std::optional<index_type> specified_chunk_size = std::nullopt,
83 {
84  if (vector_output == nullptr) {
85  if (categorical_data == nullptr) {
86  if (!has_categorical_nodes) {
87  inference::infer<D, false, forest_t, std::nullptr_t, std::nullptr_t>(forest,
88  postproc,
89  output,
90  input,
91  row_count,
92  col_count,
93  output_count,
94  nullptr,
95  nullptr,
96  infer_type,
97  specified_chunk_size,
98  device,
99  stream);
100  } else {
101  inference::infer<D, true, forest_t, std::nullptr_t, std::nullptr_t>(forest,
102  postproc,
103  output,
104  input,
105  row_count,
106  col_count,
107  output_count,
108  nullptr,
109  nullptr,
110  infer_type,
111  specified_chunk_size,
112  device,
113  stream);
114  }
115  } else {
116  inference::infer<D, true, forest_t>(forest,
117  postproc,
118  output,
119  input,
120  row_count,
121  col_count,
122  output_count,
123  nullptr,
124  categorical_data,
125  infer_type,
126  specified_chunk_size,
127  device,
128  stream);
129  }
130  } else {
131  if (categorical_data == nullptr) {
132  if (!has_categorical_nodes) {
133  inference::infer<D, false, forest_t>(forest,
134  postproc,
135  output,
136  input,
137  row_count,
138  col_count,
139  output_count,
140  vector_output,
141  nullptr,
142  infer_type,
143  specified_chunk_size,
144  device,
145  stream);
146  } else {
147  inference::infer<D, true, forest_t>(forest,
148  postproc,
149  output,
150  input,
151  row_count,
152  col_count,
153  output_count,
154  vector_output,
155  nullptr,
156  infer_type,
157  specified_chunk_size,
158  device,
159  stream);
160  }
161  } else {
162  inference::infer<D, true, forest_t>(forest,
163  postproc,
164  output,
165  input,
166  row_count,
167  col_count,
168  output_count,
169  vector_output,
170  categorical_data,
171  infer_type,
172  specified_chunk_size,
173  device,
174  stream);
175  }
176  }
177 }
178 
179 } // namespace detail
180 } // namespace fil
181 } // namespace experimental
182 } // namespace ML
void infer(forest_t const &forest, postprocessor< typename forest_t::io_type > const &postproc, typename forest_t::io_type *output, typename forest_t::io_type *input, index_type row_count, index_type col_count, index_type output_count, bool has_categorical_nodes, typename forest_t::io_type *vector_output=nullptr, typename forest_t::node_type::index_type *categorical_data=nullptr, infer_kind infer_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt, raft_proto::device_id< D > device=raft_proto::device_id< D >{}, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: infer.hpp:69
uint32_t index_type
Definition: index_type.hpp:21
infer_kind
Definition: infer_kind.hpp:20
forest< real_t > * forest_t
Definition: fil.h:89
Definition: dbscan.hpp:30
int cuda_stream
Definition: cuda_stream.hpp:25
Definition: forest.hpp:36
Definition: postprocessor.hpp:141
Definition: base.hpp:22