25 #include <type_traits> 
   39   return (
static_cast<std::underlying_type_t<row_op>
>(row_wise) |
 
   40           static_cast<std::underlying_type_t<element_op>
>(elem_wise));
 
   60 template <row_op row_wise_v, element_op elem_wise_v, 
typename io_t>
 
   65                              io_t average_factor = io_t{1},
 
   67                              io_t constant       = io_t{1})
 
   69 #pragma GCC diagnostic push 
   70 #pragma GCC diagnostic ignored "-Wunused-but-set-variable" 
   72   auto max_value = std::numeric_limits<io_t>::lowest();
 
   73 #pragma GCC diagnostic pop 
   74   for (
auto output_index = 
index_type{}; output_index < output_count; ++output_index) {
 
   75     auto workspace_index = output_index * stride;
 
   76     val[workspace_index] = val[workspace_index] / average_factor + bias;
 
   78       val[workspace_index] =
 
   79         copysign(val[workspace_index] * val[workspace_index], val[workspace_index]);
 
   81       val[workspace_index] = io_t(val[workspace_index] > io_t{});
 
   83       val[workspace_index] = io_t{1} / (io_t{1} + exp(-constant * val[workspace_index]));
 
   85       val[workspace_index] = exp(val[workspace_index] / constant);
 
   87       val[workspace_index] = log1p(exp(val[workspace_index] / constant));
 
   90       auto is_new_max = val[workspace_index] > max_value;
 
   91       max_index       = is_new_max * output_index + (!is_new_max) * max_index;
 
   92       max_value       = is_new_max * val[workspace_index] + (!is_new_max) * max_value;
 
   99 #pragma GCC diagnostic push 
  100 #pragma GCC diagnostic ignored "-Wunused-but-set-variable" 
  101     auto softmax_normalization = io_t{};
 
  102 #pragma GCC diagnostic pop 
  104       for (
auto workspace_index = 
index_type{}; workspace_index < output_count * stride;
 
  105            workspace_index += stride) {
 
  106         val[workspace_index] = exp(val[workspace_index] - max_value);
 
  107         softmax_normalization += val[workspace_index];
 
  111     for (
auto output_index = 
index_type{}; output_index < output_count; ++output_index) {
 
  112       auto workspace_index = output_index * stride;
 
  114         out[output_index] = val[workspace_index] / softmax_normalization;
 
  116         out[output_index] = val[workspace_index];
 
  139 template <
typename io_t>
 
  143                             io_t average_factor  = io_t{1},
 
  145                             io_t constant        = io_t{1})
 
  146     : average_factor_{average_factor},
 
  150       elem_wise_{elem_wise}
 
  162           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  166           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  170           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  174           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  178           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  182           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  186           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  190           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  194           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  198           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  202           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  206           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  210           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  214           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  218           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  222           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  226           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  229         postprocess<row_op::disable, element_op::disable>(
 
  230           val, output_count, out, stride, average_factor_, bias_, constant_);
 
  235   io_t average_factor_;
 
#define DEVICE
Definition: gpu_support.hpp:35
 
#define HOST
Definition: gpu_support.hpp:34
 
row_op
Definition: postproc_ops.hpp:21
 
element_op
Definition: postproc_ops.hpp:28
 
HOST DEVICE constexpr auto ops_to_val(row_op row_wise, element_op elem_wise)
Definition: postprocessor.hpp:37
 
HOST DEVICE void postprocess(io_t *val, index_type output_count, io_t *out, index_type stride=index_type{1}, io_t average_factor=io_t{1}, io_t bias=io_t{0}, io_t constant=io_t{1})
Definition: postprocessor.hpp:61
 
uint32_t index_type
Definition: index_type.hpp:20
 
Definition: dbscan.hpp:29
 
Definition: postprocessor.hpp:140
 
HOST DEVICE postprocessor(row_op row_wise=row_op::disable, element_op elem_wise=element_op::disable, io_t average_factor=io_t{1}, io_t bias=io_t{0}, io_t constant=io_t{1})
Definition: postprocessor.hpp:141
 
HOST DEVICE void operator()(io_t *val, index_type output_count, io_t *out, index_type stride=index_type{1}) const
Definition: postprocessor.hpp:154