14 #include <raft/core/error.hpp>
19 #ifdef omp_get_max_threads
20 #if omp_get_max_threads() != 1
21 #error "Inconsistent placeholders for omp_get_max_threads"
24 #define omp_get_max_threads() 1
75 template <
bool has_categorical_nodes,
78 typename vector_output_t = std::nullptr_t,
79 typename categorical_data_t = std::nullptr_t>
82 typename forest_t::io_type* output,
83 typename forest_t::io_type
const* input,
87 index_type chunk_size = hardware_constructive_interference_size,
88 index_type grove_size = hardware_constructive_interference_size,
89 vector_output_t vector_output_p =
nullptr,
90 categorical_data_t categorical_data =
nullptr,
93 auto constexpr has_vector_leaves = !std::is_same_v<vector_output_t, std::nullptr_t>;
94 auto constexpr has_nonlocal_categories = !std::is_same_v<categorical_data_t, std::nullptr_t>;
96 using node_t =
typename forest_t::node_type;
98 using output_t =
typename forest_t::template raw_output_type<vector_output_t>;
112 (num_outputs *
static_cast<std::uint64_t
>(num_grove));
113 if (max_num_row >= 3) {
118 if (row_count > max_num_row) {
119 throw runtime_error(std::string(
"Input size too large! Input should be at most ") +
120 std::to_string(max_num_row) +
".");
124 auto output_workspace = std::vector<output_t>(row_count * num_outputs * num_grove, output_t{});
125 auto const task_count = num_grove * num_chunk;
127 #pragma omp parallel num_threads(std::min(index_type(omp_get_max_threads()), task_count))
131 for (
auto task_index =
index_type{}; task_index < task_count; ++task_index) {
132 auto const grove_index = task_index / num_chunk;
133 auto const chunk_index = task_index % num_chunk;
134 auto const start_row = chunk_index * chunk_size;
135 auto const end_row = std::min(start_row + chunk_size, row_count);
136 auto const start_tree = grove_index * grove_size;
137 auto const end_tree = std::min(start_tree + grove_size, num_tree);
139 for (
auto row_index = start_row; row_index < end_row; ++row_index) {
140 for (
auto tree_index = start_tree; tree_index < end_tree; ++tree_index) {
142 std::conditional_t<predict_leaf,
144 std::conditional_t<has_vector_leaves,
146 typename node_t::threshold_type>>{};
148 has_categorical_nodes,
149 has_nonlocal_categories,
151 forest, tree_index, input + row_index * col_count, categorical_data);
152 if constexpr (predict_leaf) {
153 output_workspace[row_index * num_outputs * num_grove + tree_index * num_grove +
154 grove_index] =
static_cast<typename forest_t::io_type
>(tree_output);
157 if constexpr (has_vector_leaves) {
158 auto output_offset = (row_index * num_outputs * num_grove +
159 tree_index * default_num_outputs * num_grove *
162 for (
auto output_index =
index_type{}; output_index < default_num_outputs;
164 output_workspace[output_offset + output_index * num_grove] +=
165 vector_output_p[tree_output * default_num_outputs + output_index];
169 (row_index * num_outputs * num_grove +
170 (tree_index % default_num_outputs) * num_grove *
173 output_workspace[output_offset] += tree_output;
182 for (
auto row_index =
index_type{}; row_index < row_count; ++row_index) {
183 for (
auto output_index =
index_type{}; output_index < num_outputs; ++output_index) {
184 auto grove_offset = (row_index * num_outputs * num_grove + output_index * num_grove);
186 output_workspace[grove_offset] =
187 std::accumulate(std::begin(output_workspace) + grove_offset,
188 std::begin(output_workspace) + grove_offset + num_grove,
192 output_workspace.data() + row_index * num_outputs * num_grove,
195 output + row_index * num_outputs,
math_t max(math_t a, math_t b)
Definition: learning_rate.h:16
void infer_kernel_cpu(forest_t const &forest, postprocessor< typename forest_t::io_type > const &postproc, typename forest_t::io_type *output, typename forest_t::io_type const *input, index_type row_count, index_type col_count, index_type num_outputs, index_type chunk_size=hardware_constructive_interference_size, index_type grove_size=hardware_constructive_interference_size, vector_output_t vector_output_p=nullptr, categorical_data_t categorical_data=nullptr, infer_kind infer_type=infer_kind::default_kind)
Definition: cpu.hpp:80
HOST DEVICE auto evaluate_tree(forest_t const &forest, index_type tree_index, io_t const *__restrict__ row, categorical_data_t categorical_data)
Definition: evaluate_tree.hpp:162
infer_kind
Definition: infer_kind.hpp:8
uint32_t index_type
Definition: index_type.hpp:9
Definition: dbscan.hpp:18
HOST DEVICE constexpr auto ceildiv(T dividend, U divisor)
Definition: ceildiv.hpp:10
Definition: forest.hpp:24
HOST DEVICE auto num_outputs() const
Definition: forest.hpp:65
HOST DEVICE auto tree_count() const
Definition: forest.hpp:61
HOST DEVICE const auto * bias() const
Definition: forest.hpp:58
Definition: postprocessor.hpp:135
Definition: exceptions.hpp:52