25 #include <type_traits>
32 namespace experimental {
40 return (
static_cast<std::underlying_type_t<row_op>
>(row_wise) |
41 static_cast<std::underlying_type_t<element_op>
>(elem_wise));
61 template <row_op row_wise_v, element_op elem_wise_v,
typename io_t>
66 io_t average_factor = io_t{1},
68 io_t constant = io_t{1})
70 #pragma GCC diagnostic push
71 #pragma GCC diagnostic ignored "-Wunused-but-set-variable"
73 auto max_value = std::numeric_limits<io_t>::lowest();
74 #pragma GCC diagnostic pop
75 for (
auto output_index =
index_type{}; output_index < output_count; ++output_index) {
76 auto workspace_index = output_index * stride;
77 val[workspace_index] = val[workspace_index] / average_factor + bias;
79 val[workspace_index] =
80 copysign(val[workspace_index] * val[workspace_index], val[workspace_index]);
82 val[workspace_index] = io_t(val[workspace_index] > io_t{});
84 val[workspace_index] = io_t{1} / (io_t{1} + exp(-constant * val[workspace_index]));
86 val[workspace_index] = exp(val[workspace_index] / constant);
88 val[workspace_index] = log1p(exp(val[workspace_index] / constant));
91 auto is_new_max = val[workspace_index] > max_value;
92 max_index = is_new_max * output_index + (!is_new_max) * max_index;
93 max_value = is_new_max * val[workspace_index] + (!is_new_max) * max_value;
100 #pragma GCC diagnostic push
101 #pragma GCC diagnostic ignored "-Wunused-but-set-variable"
102 auto softmax_normalization = io_t{};
103 #pragma GCC diagnostic pop
105 for (
auto workspace_index =
index_type{}; workspace_index < output_count * stride;
106 workspace_index += stride) {
107 val[workspace_index] = exp(val[workspace_index] - max_value);
108 softmax_normalization += val[workspace_index];
112 for (
auto output_index =
index_type{}; output_index < output_count; ++output_index) {
113 auto workspace_index = output_index * stride;
115 out[output_index] = val[workspace_index] / softmax_normalization;
117 out[output_index] = val[workspace_index];
140 template <
typename io_t>
144 io_t average_factor = io_t{1},
146 io_t constant = io_t{1})
147 : average_factor_{average_factor},
151 elem_wise_{elem_wise}
163 val, output_count, out, stride, average_factor_, bias_, constant_);
167 val, output_count, out, stride, average_factor_, bias_, constant_);
171 val, output_count, out, stride, average_factor_, bias_, constant_);
175 val, output_count, out, stride, average_factor_, bias_, constant_);
179 val, output_count, out, stride, average_factor_, bias_, constant_);
183 val, output_count, out, stride, average_factor_, bias_, constant_);
187 val, output_count, out, stride, average_factor_, bias_, constant_);
191 val, output_count, out, stride, average_factor_, bias_, constant_);
195 val, output_count, out, stride, average_factor_, bias_, constant_);
199 val, output_count, out, stride, average_factor_, bias_, constant_);
203 val, output_count, out, stride, average_factor_, bias_, constant_);
207 val, output_count, out, stride, average_factor_, bias_, constant_);
211 val, output_count, out, stride, average_factor_, bias_, constant_);
215 val, output_count, out, stride, average_factor_, bias_, constant_);
219 val, output_count, out, stride, average_factor_, bias_, constant_);
223 val, output_count, out, stride, average_factor_, bias_, constant_);
227 val, output_count, out, stride, average_factor_, bias_, constant_);
230 postprocess<row_op::disable, element_op::disable>(
231 val, output_count, out, stride, average_factor_, bias_, constant_);
236 io_t average_factor_;
#define DEVICE
Definition: gpu_support.hpp:35
#define HOST
Definition: gpu_support.hpp:34
element_op
Definition: postproc_ops.hpp:29
uint32_t index_type
Definition: index_type.hpp:21
HOST DEVICE constexpr auto ops_to_val(row_op row_wise, element_op elem_wise)
Definition: postprocessor.hpp:38
row_op
Definition: postproc_ops.hpp:22
HOST DEVICE void postprocess(io_t *val, index_type output_count, io_t *out, index_type stride=index_type{1}, io_t average_factor=io_t{1}, io_t bias=io_t{0}, io_t constant=io_t{1})
Definition: postprocessor.hpp:62
Definition: dbscan.hpp:30
Definition: postprocessor.hpp:141
HOST DEVICE postprocessor(row_op row_wise=row_op::disable, element_op elem_wise=element_op::disable, io_t average_factor=io_t{1}, io_t bias=io_t{0}, io_t constant=io_t{1})
Definition: postprocessor.hpp:142
HOST DEVICE void operator()(io_t *val, index_type output_count, io_t *out, index_type stride=index_type{1}) const
Definition: postprocessor.hpp:155