postprocessor.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
21 
22 #include <stddef.h>
23 
24 #include <limits>
25 #include <type_traits>
26 
27 #ifndef __CUDACC__
28 #include <math.h>
29 #endif
30 
31 namespace ML {
32 namespace fil {
33 
34 /* Convert the postprocessing operations into a single value
35  * representing what must be done in the inference kernel
36  */
37 HOST DEVICE inline auto constexpr ops_to_val(row_op row_wise, element_op elem_wise)
38 {
39  return (static_cast<std::underlying_type_t<row_op>>(row_wise) |
40  static_cast<std::underlying_type_t<element_op>>(elem_wise));
41 }
42 
43 /*
44  * Perform postprocessing on raw forest output
45  *
46  * @param val Pointer to the raw forest output
47  * @param output_count The number of output values per row
48  * @param out Pointer to the output buffer
49  * @param stride Number of elements between the first element that must be
50  * summed for a particular output element and the next. This is typically
51  * equal to the number of "groves" of trees over which the computation
52  * was divided.
53  * @param average_factor The factor by which to divide during the
54  * normalization step of postprocessing
55  * @param bias The bias factor to subtract off during the
56  * normalization step of postprocessing
57  * @param constant If the postprocessing operation requires a constant,
58  * it can be passed here.
59  */
60 template <row_op row_wise_v, element_op elem_wise_v, typename io_t>
61 HOST DEVICE void postprocess(io_t* val,
62  index_type output_count,
63  io_t* out,
64  index_type stride = index_type{1},
65  io_t average_factor = io_t{1},
66  io_t bias = io_t{0},
67  io_t constant = io_t{1})
68 {
69 #pragma GCC diagnostic push
70 #pragma GCC diagnostic ignored "-Wunused-but-set-variable"
71  auto max_index = index_type{};
72  auto max_value = std::numeric_limits<io_t>::lowest();
73 #pragma GCC diagnostic pop
74  for (auto output_index = index_type{}; output_index < output_count; ++output_index) {
75  auto workspace_index = output_index * stride;
76  val[workspace_index] = val[workspace_index] / average_factor + bias;
77  if constexpr (elem_wise_v == element_op::signed_square) {
78  val[workspace_index] =
79  copysign(val[workspace_index] * val[workspace_index], val[workspace_index]);
80  } else if constexpr (elem_wise_v == element_op::hinge) {
81  val[workspace_index] = io_t(val[workspace_index] > io_t{});
82  } else if constexpr (elem_wise_v == element_op::sigmoid) {
83  val[workspace_index] = io_t{1} / (io_t{1} + exp(-constant * val[workspace_index]));
84  } else if constexpr (elem_wise_v == element_op::exponential) {
85  val[workspace_index] = exp(val[workspace_index] / constant);
86  } else if constexpr (elem_wise_v == element_op::logarithm_one_plus_exp) {
87  val[workspace_index] = log1p(exp(val[workspace_index] / constant));
88  }
89  if constexpr (row_wise_v == row_op::softmax || row_wise_v == row_op::max_index) {
90  auto is_new_max = val[workspace_index] > max_value;
91  max_index = is_new_max * output_index + (!is_new_max) * max_index;
92  max_value = is_new_max * val[workspace_index] + (!is_new_max) * max_value;
93  }
94  }
95 
96  if constexpr (row_wise_v == row_op::max_index) {
97  *out = max_index;
98  } else {
99 #pragma GCC diagnostic push
100 #pragma GCC diagnostic ignored "-Wunused-but-set-variable"
101  auto softmax_normalization = io_t{};
102 #pragma GCC diagnostic pop
103  if constexpr (row_wise_v == row_op::softmax) {
104  for (auto workspace_index = index_type{}; workspace_index < output_count * stride;
105  workspace_index += stride) {
106  val[workspace_index] = exp(val[workspace_index] - max_value);
107  softmax_normalization += val[workspace_index];
108  }
109  }
110 
111  for (auto output_index = index_type{}; output_index < output_count; ++output_index) {
112  auto workspace_index = output_index * stride;
113  if constexpr (row_wise_v == row_op::softmax) {
114  out[output_index] = val[workspace_index] / softmax_normalization;
115  } else {
116  out[output_index] = val[workspace_index];
117  }
118  }
119  }
120 }
121 
122 /*
123  * Struct which holds all data necessary to perform postprocessing on raw
124  * output of a forest model
125  *
126  * @tparam io_t The type used for input and output to/from the model
127  * (typically float/double)
128  * @param row_wise Enum value representing the row-wise post-processing
129  * operation to perform on the output
130  * @param elem_wise Enum value representing the element-wise post-processing
131  * operation to perform on the output
132  * @param average_factor The factor by which to divide during the
133  * normalization step of postprocessing
134  * @param bias The bias factor to subtract off during the
135  * normalization step of postprocessing
136  * @param constant If the postprocessing operation requires a constant,
137  * it can be passed here.
138  */
139 template <typename io_t>
142  element_op elem_wise = element_op::disable,
143  io_t average_factor = io_t{1},
144  io_t bias = io_t{0},
145  io_t constant = io_t{1})
146  : average_factor_{average_factor},
147  bias_{bias},
148  constant_{constant},
149  row_wise_{row_wise},
150  elem_wise_{elem_wise}
151  {
152  }
153 
154  HOST DEVICE void operator()(io_t* val,
155  index_type output_count,
156  io_t* out,
157  index_type stride = index_type{1}) const
158  {
159  switch (ops_to_val(row_wise_, elem_wise_)) {
162  val, output_count, out, stride, average_factor_, bias_, constant_);
163  break;
166  val, output_count, out, stride, average_factor_, bias_, constant_);
167  break;
170  val, output_count, out, stride, average_factor_, bias_, constant_);
171  break;
174  val, output_count, out, stride, average_factor_, bias_, constant_);
175  break;
178  val, output_count, out, stride, average_factor_, bias_, constant_);
179  break;
181  postprocess<row_op::softmax, element_op::disable>(
182  val, output_count, out, stride, average_factor_, bias_, constant_);
183  break;
186  val, output_count, out, stride, average_factor_, bias_, constant_);
187  break;
189  postprocess<row_op::softmax, element_op::hinge>(
190  val, output_count, out, stride, average_factor_, bias_, constant_);
191  break;
193  postprocess<row_op::softmax, element_op::sigmoid>(
194  val, output_count, out, stride, average_factor_, bias_, constant_);
195  break;
198  val, output_count, out, stride, average_factor_, bias_, constant_);
199  break;
202  val, output_count, out, stride, average_factor_, bias_, constant_);
203  break;
206  val, output_count, out, stride, average_factor_, bias_, constant_);
207  break;
210  val, output_count, out, stride, average_factor_, bias_, constant_);
211  break;
214  val, output_count, out, stride, average_factor_, bias_, constant_);
215  break;
218  val, output_count, out, stride, average_factor_, bias_, constant_);
219  break;
222  val, output_count, out, stride, average_factor_, bias_, constant_);
223  break;
226  val, output_count, out, stride, average_factor_, bias_, constant_);
227  break;
228  default:
229  postprocess<row_op::disable, element_op::disable>(
230  val, output_count, out, stride, average_factor_, bias_, constant_);
231  }
232  }
233 
234  private:
235  io_t average_factor_;
236  io_t bias_;
237  io_t constant_;
238  row_op row_wise_;
239  element_op elem_wise_;
240 };
241 } // namespace fil
242 } // namespace ML
#define DEVICE
Definition: gpu_support.hpp:35
#define HOST
Definition: gpu_support.hpp:34
row_op
Definition: postproc_ops.hpp:21
element_op
Definition: postproc_ops.hpp:28
HOST DEVICE constexpr auto ops_to_val(row_op row_wise, element_op elem_wise)
Definition: postprocessor.hpp:37
HOST DEVICE void postprocess(io_t *val, index_type output_count, io_t *out, index_type stride=index_type{1}, io_t average_factor=io_t{1}, io_t bias=io_t{0}, io_t constant=io_t{1})
Definition: postprocessor.hpp:61
uint32_t index_type
Definition: index_type.hpp:20
Definition: dbscan.hpp:29
Definition: postprocessor.hpp:140
HOST DEVICE postprocessor(row_op row_wise=row_op::disable, element_op elem_wise=element_op::disable, io_t average_factor=io_t{1}, io_t bias=io_t{0}, io_t constant=io_t{1})
Definition: postprocessor.hpp:141
HOST DEVICE void operator()(io_t *val, index_type output_count, io_t *out, index_type stride=index_type{1}) const
Definition: postprocessor.hpp:154