15 #include <type_traits>
// Return statement of ops_to_val (its signature lines are not part of this
// listing): each enum argument is converted to its underlying integer type
// and the two integers are combined with a bitwise OR.
29 return (
static_cast<std::underlying_type_t<row_op>
>(row_wise) |
30 static_cast<std::underlying_type_t<element_op>
>(elem_wise));
// Fragment of the `postprocess` function template. The embedded line numbers
// skip, so parts of the original are absent here: the leading parameters, the
// declarations of `use_bias` and `max_index`, and the conditions that guard
// the individual element-wise / row-wise statements below.
//
// Template parameters: a row_op value, an element_op value and the value
// type io_t used for the workspace, the bias and the output.
51 template <row_op row_wise_v, element_op elem_wise_v,
typename io_t>
// Both trailing parameters default to io_t{1}.
58 io_t average_factor = io_t{1},
59 io_t constant = io_t{1})
// The pragmas silence -Wunused-but-set-variable for the declaration(s) in
// between; max_value starts at the lowest value representable in io_t.
62 #pragma GCC diagnostic push
63 #pragma GCC diagnostic ignored "-Wunused-but-set-variable"
65 auto max_value = std::numeric_limits<io_t>::lowest();
66 #pragma GCC diagnostic pop
// First pass: visit every output; its workspace slot is output_index * stride.
67 for (
auto output_index =
index_type{}; output_index < output_count; ++output_index) {
68 auto workspace_index = output_index * stride;
// Divide by average_factor and add the bias term. `use_bias` multiplies both
// the bias index and the bias value, so when it is false the index is 0 and
// the added term is bias[0] * 0.
// NOTE(review): bias[0] is still read when use_bias is false — confirm that
// bias is always dereferenceable in that case.
71 val[workspace_index] =
72 val[workspace_index] / average_factor +
73 bias[output_index *
static_cast<index_type>(use_bias)] *
static_cast<io_t
>(use_bias);
// The next five assignments are alternative element-wise transforms; the
// conditions selecting among them are not present in this listing.
// Square of the value, keeping the value's sign.
75 val[workspace_index] =
76 copysign(val[workspace_index] * val[workspace_index], val[workspace_index]);
// io_t(1) if the value is greater than zero, otherwise io_t(0).
78 val[workspace_index] = io_t(val[workspace_index] > io_t{});
// 1 / (1 + exp(-constant * value)).
80 val[workspace_index] = io_t{1} / (io_t{1} + exp(-constant * val[workspace_index]));
// exp(value / constant).
82 val[workspace_index] = exp(val[workspace_index] / constant);
// log(1 + exp(value / constant)), computed via log1p.
84 val[workspace_index] = log1p(exp(val[workspace_index] / constant));
// Running maximum and its index, selected arithmetically
// (flag * new + !flag * old) instead of with a branch.
// NOTE(review): if io_t is a floating-point type, 0 * inf and 0 * NaN are NaN
// under IEEE-754, so a non-finite value here could turn max_value into NaN —
// confirm inputs are finite or that this is acceptable.
87 auto is_new_max = val[workspace_index] > max_value;
88 max_index = is_new_max * output_index + (!is_new_max) * max_index;
89 max_value = is_new_max * val[workspace_index] + (!is_new_max) * max_value;
// Accumulator for the sum of exponentials, zero-initialised; the same warning
// suppression as above is applied to its declaration.
96 #pragma GCC diagnostic push
97 #pragma GCC diagnostic ignored "-Wunused-but-set-variable"
98 auto softmax_normalization = io_t{};
99 #pragma GCC diagnostic pop
// Second pass over the strided workspace: replace each value with
// exp(value - max_value) and add it to softmax_normalization.
// (Presumably the shift by max_value is for numerical stability — TODO confirm.)
101 for (
auto workspace_index =
index_type{}; workspace_index < output_count * stride;
102 workspace_index += stride) {
103 val[workspace_index] = exp(val[workspace_index] - max_value);
104 softmax_normalization += val[workspace_index];
// Final pass: write each output from its strided workspace slot, either
// divided by softmax_normalization or unchanged. The condition choosing
// between the two assignments is not present in this listing.
108 for (
auto output_index =
index_type{}; output_index < output_count; ++output_index) {
109 auto workspace_index = output_index * stride;
111 out[output_index] = val[workspace_index] / softmax_normalization;
113 out[output_index] = val[workspace_index];
// Fragment of the `postprocessor` class template, parameterised on the value
// type io_t. The class head, most of the constructor signature, the call
// operator's signature and the branch conditions are absent from this listing.
134 template <
typename io_t>
// Constructor tail: average_factor and constant both default to io_t{1}.
138 io_t average_factor = io_t{1},
139 io_t constant = io_t{1})
// Visible part of the member-initialiser list: the arguments are stored in
// the correspondingly named trailing-underscore members.
140 : average_factor_{average_factor},
143 elem_wise_{elem_wise}
// The seventeen lines below are argument lists of calls whose callee lines
// are missing from this listing. Every one forwards the same values: the
// call's own arguments followed by the stored average_factor_ and constant_.
// (Presumably each is a `postprocess` instantiation for a different
// row_op / element_op pair, like the visible call at the end — TODO confirm.)
157 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
161 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
165 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
169 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
173 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
177 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
181 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
185 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
189 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
193 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
197 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
201 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
205 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
209 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
213 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
217 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
221 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
// The only call whose callee is visible: postprocess instantiated with both
// the row-wise and the element-wise operation set to `disable`.
224 postprocess<row_op::disable, element_op::disable>(
225 infer_type, val, output_count, bias, out, stride, average_factor_, constant_);
// Stored copy of the constructor's average_factor argument.
230 io_t average_factor_;
#define DEVICE
Definition: gpu_support.hpp:24
#define HOST
Definition: gpu_support.hpp:23
infer_kind
Definition: infer_kind.hpp:8
row_op
Definition: postproc_ops.hpp:10
HOST DEVICE void postprocess(infer_kind infer_type, io_t *val, index_type output_count, const io_t *bias, io_t *out, index_type stride=index_type{1}, io_t average_factor=io_t{1}, io_t constant=io_t{1})
Definition: postprocessor.hpp:52
element_op
Definition: postproc_ops.hpp:17
HOST DEVICE constexpr auto ops_to_val(row_op row_wise, element_op elem_wise)
Definition: postprocessor.hpp:27
uint32_t index_type
Definition: index_type.hpp:9
Definition: dbscan.hpp:18
Definition: postprocessor.hpp:135
HOST DEVICE void operator()(infer_kind infer_type, io_t *val, index_type output_count, const io_t *bias, io_t *out, index_type stride=index_type{1}) const
Definition: postprocessor.hpp:147
HOST DEVICE postprocessor(row_op row_wise=row_op::disable, element_op elem_wise=element_op::disable, io_t average_factor=io_t{1}, io_t constant=io_t{1})
Definition: postprocessor.hpp:136