postprocessor.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
11 
12 #include <stddef.h>
13 
14 #include <limits>
15 #include <type_traits>
16 
17 #ifndef __CUDACC__
18 #include <math.h>
19 #endif
20 
21 namespace ML {
22 namespace fil {
23 
24 /* Convert the postprocessing operations into a single value
25  * representing what must be done in the inference kernel
26  */
27 HOST DEVICE inline auto constexpr ops_to_val(row_op row_wise, element_op elem_wise)
28 {
29  return (static_cast<std::underlying_type_t<row_op>>(row_wise) |
30  static_cast<std::underlying_type_t<element_op>>(elem_wise));
31 }
32 
33 /*
34  * Perform postprocessing on raw forest output
35  *
36  * @param val Pointer to the raw forest output
37  * @param output_count The number of output values per row
38  * @param bias Pointer to bias vector, which is added to the output
39  * as part of the postprocessing step. The bias vector should have
40  * the same length as output_count.
41  * @param out Pointer to the output buffer
42  * @param stride Number of elements between the first element that must be
43  * summed for a particular output element and the next. This is typically
44  * equal to the number of "groves" of trees over which the computation
45  * was divided.
46  * @param average_factor The factor by which to divide during the
47  * normalization step of postprocessing
48  * @param constant If the postprocessing operation requires a constant,
49  * it can be passed here.
50  */
51 template <row_op row_wise_v, element_op elem_wise_v, typename io_t>
53  io_t* val,
54  index_type output_count,
55  const io_t* bias,
56  io_t* out,
57  index_type stride = index_type{1},
58  io_t average_factor = io_t{1},
59  io_t constant = io_t{1})
60 {
61  const bool use_bias = infer_type == infer_kind::default_kind;
62 #pragma GCC diagnostic push
63 #pragma GCC diagnostic ignored "-Wunused-but-set-variable"
64  auto max_index = index_type{};
65  auto max_value = std::numeric_limits<io_t>::lowest();
66 #pragma GCC diagnostic pop
67  for (auto output_index = index_type{}; output_index < output_count; ++output_index) {
68  auto workspace_index = output_index * stride;
69  // Add the bias term if use_bias is true.
70  // The following expression is written to avoid branching.
71  val[workspace_index] =
72  val[workspace_index] / average_factor +
73  bias[output_index * static_cast<index_type>(use_bias)] * static_cast<io_t>(use_bias);
74  if constexpr (elem_wise_v == element_op::signed_square) {
75  val[workspace_index] =
76  copysign(val[workspace_index] * val[workspace_index], val[workspace_index]);
77  } else if constexpr (elem_wise_v == element_op::hinge) {
78  val[workspace_index] = io_t(val[workspace_index] > io_t{});
79  } else if constexpr (elem_wise_v == element_op::sigmoid) {
80  val[workspace_index] = io_t{1} / (io_t{1} + exp(-constant * val[workspace_index]));
81  } else if constexpr (elem_wise_v == element_op::exponential) {
82  val[workspace_index] = exp(val[workspace_index] / constant);
83  } else if constexpr (elem_wise_v == element_op::logarithm_one_plus_exp) {
84  val[workspace_index] = log1p(exp(val[workspace_index] / constant));
85  }
86  if constexpr (row_wise_v == row_op::softmax || row_wise_v == row_op::max_index) {
87  auto is_new_max = val[workspace_index] > max_value;
88  max_index = is_new_max * output_index + (!is_new_max) * max_index;
89  max_value = is_new_max * val[workspace_index] + (!is_new_max) * max_value;
90  }
91  }
92 
93  if constexpr (row_wise_v == row_op::max_index) {
94  *out = max_index;
95  } else {
96 #pragma GCC diagnostic push
97 #pragma GCC diagnostic ignored "-Wunused-but-set-variable"
98  auto softmax_normalization = io_t{};
99 #pragma GCC diagnostic pop
100  if constexpr (row_wise_v == row_op::softmax) {
101  for (auto workspace_index = index_type{}; workspace_index < output_count * stride;
102  workspace_index += stride) {
103  val[workspace_index] = exp(val[workspace_index] - max_value);
104  softmax_normalization += val[workspace_index];
105  }
106  }
107 
108  for (auto output_index = index_type{}; output_index < output_count; ++output_index) {
109  auto workspace_index = output_index * stride;
110  if constexpr (row_wise_v == row_op::softmax) {
111  out[output_index] = val[workspace_index] / softmax_normalization;
112  } else {
113  out[output_index] = val[workspace_index];
114  }
115  }
116  }
117 }
118 
119 /*
120  * Struct which holds all data necessary to perform postprocessing on raw
121  * output of a forest model
122  *
123  * @tparam io_t The type used for input and output to/from the model
124  * (typically float/double)
125  * @param row_wise Enum value representing the row-wise post-processing
126  * operation to perform on the output
127  * @param elem_wise Enum value representing the element-wise post-processing
128  * operation to perform on the output
129  * @param average_factor The factor by which to divide during the
130  * normalization step of postprocessing
131  * @param constant If the postprocessing operation requires a constant,
132  * it can be passed here.
133  */
134 template <typename io_t>
137  element_op elem_wise = element_op::disable,
138  io_t average_factor = io_t{1},
139  io_t constant = io_t{1})
140  : average_factor_{average_factor},
141  constant_{constant},
142  row_wise_{row_wise},
143  elem_wise_{elem_wise}
144  {
145  }
146 
147  HOST DEVICE void operator()(infer_kind infer_type,
148  io_t* val,
149  index_type output_count,
150  const io_t* bias,
151  io_t* out,
152  index_type stride = index_type{1}) const
153  {
154  switch (ops_to_val(row_wise_, elem_wise_)) {
157  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
158  break;
161  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
162  break;
165  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
166  break;
169  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
170  break;
173  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
174  break;
176  postprocess<row_op::softmax, element_op::disable>(
177  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
178  break;
181  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
182  break;
184  postprocess<row_op::softmax, element_op::hinge>(
185  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
186  break;
188  postprocess<row_op::softmax, element_op::sigmoid>(
189  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
190  break;
193  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
194  break;
197  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
198  break;
201  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
202  break;
205  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
206  break;
209  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
210  break;
213  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
214  break;
217  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
218  break;
221  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
222  break;
223  default:
224  postprocess<row_op::disable, element_op::disable>(
225  infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
226  }
227  }
228 
229  private:
230  io_t average_factor_;
231  io_t constant_;
232  row_op row_wise_;
233  element_op elem_wise_;
234 };
235 } // namespace fil
236 } // namespace ML
#define DEVICE
Definition: gpu_support.hpp:24
#define HOST
Definition: gpu_support.hpp:23
infer_kind
Definition: infer_kind.hpp:8
row_op
Definition: postproc_ops.hpp:10
HOST DEVICE void postprocess(infer_kind infer_type, io_t *val, index_type output_count, const io_t *bias, io_t *out, index_type stride=index_type{1}, io_t average_factor=io_t{1}, io_t constant=io_t{1})
Definition: postprocessor.hpp:52
element_op
Definition: postproc_ops.hpp:17
HOST DEVICE constexpr auto ops_to_val(row_op row_wise, element_op elem_wise)
Definition: postprocessor.hpp:27
uint32_t index_type
Definition: index_type.hpp:9
Definition: dbscan.hpp:18
Definition: postprocessor.hpp:135
HOST DEVICE void operator()(infer_kind infer_type, io_t *val, index_type output_count, const io_t *bias, io_t *out, index_type stride=index_type{1}) const
Definition: postprocessor.hpp:147
HOST DEVICE postprocessor(row_op row_wise=row_op::disable, element_op elem_wise=element_op::disable, io_t average_factor=io_t{1}, io_t constant=io_t{1})
Definition: postprocessor.hpp:136