Loading [MathJax]/extensions/tex2jax.js
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
cpu.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
23 
// Make omp_get_max_threads() usable whether or not the translation unit is
// compiled with OpenMP. With OpenMP enabled the real declaration comes from
// <omp.h>.
24 #ifdef _OPENMP
25 #include <omp.h>
26 #else
// Without OpenMP, a function-like macro that expands to 1 stands in for the
// runtime call. If a macro of that name is already defined, it is accepted
// only when it also expands to 1; anything else is reported as a hard error
// instead of being silently redefined.
27 #ifdef omp_get_max_threads
28 #if omp_get_max_threads() != 1
29 #error "Inconsistent placeholders for omp_get_max_threads"
30 #endif
31 #else
// No placeholder yet: define it here.
32 #define omp_get_max_threads() 1
33 #endif
34 #endif
35 
36 #include <algorithm>
37 #include <cstddef>
38 #include <iostream>
39 #include <new>
40 #include <numeric>
41 #include <vector>
42 
43 namespace ML {
44 namespace experimental {
45 namespace fil {
46 namespace detail {
47 
/**
 * CPU inference kernel: evaluates every tree of `forest` on every input row,
 * accumulates the per-tree results in a temporary workspace and hands each
 * row's accumulated values to `postproc`.
 *
 * NOTE(review): this listing is missing its lines 86-88 and 98, so the
 * function name and the `forest`, `postproc` and `infer_type` parameters used
 * in the body are not declared in the visible text. The signature summary
 * elsewhere on this page gives them as
 *   void infer_kernel_cpu(forest_t const& forest,
 *                         postprocessor<typename forest_t::io_type> const& postproc,
 *                         ..., infer_kind infer_type = infer_kind::default_kind)
 * Confirm against the real header before relying on this.
 *
 * Work is split into tasks, one per (grove, chunk) pair, where a grove is a
 * run of up to `grove_size` consecutive trees and a chunk is a run of up to
 * `chunk_size` consecutive rows. Tasks are distributed over OpenMP threads.
 *
 * @tparam has_categorical_nodes Forwarded unchanged to evaluate_tree.
 * @tparam predict_leaf When true, the value returned by evaluate_tree is
 * stored per tree (cast to forest_t::io_type) instead of being accumulated.
 * @tparam forest_t Forest type; must provide io_type, node_type,
 * raw_output_type<>, tree_count() and num_outputs().
 * @tparam vector_output_t Type of `vector_output_p`; any type other than
 * std::nullptr_t enables the vector-leaf code path.
 * @tparam categorical_data_t Type of `categorical_data`; any type other than
 * std::nullptr_t sets the has_nonlocal_categories flag passed to
 * evaluate_tree.
 *
 * @param output Destination buffer; row r is passed to postproc as
 * `output + r * num_outputs`.
 * @param input Input features; row r is read starting at
 * `input + r * col_count`.
 * @param row_count Number of input rows.
 * @param col_count Number of elements between consecutive rows of `input`.
 * @param num_outputs Number of workspace slots per row (per grove).
 * @param chunk_size Maximum number of rows handled by one task.
 * @param grove_size Maximum number of trees handled by one task.
 * @param vector_output_p Leaf-vector storage, indexed as
 * `vector_output_p[leaf * forest.num_outputs() + output_index]` when vector
 * leaves are enabled.
 * @param categorical_data Forwarded unchanged to evaluate_tree.
 */
82 template <bool has_categorical_nodes,
83  bool predict_leaf,
84  typename forest_t,
85  typename vector_output_t = std::nullptr_t,
86  typename categorical_data_t = std::nullptr_t>
// (listing lines 87-88 — function name, `forest`, `postproc` — are absent here)
89  typename forest_t::io_type* output,
90  typename forest_t::io_type const* input,
91  index_type row_count,
92  index_type col_count,
93  index_type num_outputs,
94  index_type chunk_size = hardware_constructive_interference_size,
95  index_type grove_size = hardware_constructive_interference_size,
96  vector_output_t vector_output_p = nullptr,
97  categorical_data_t categorical_data = nullptr,
// (listing line 98 — the trailing `infer_type` parameter — is absent here)
99 {
 // Both feature flags are derived purely from whether the optional argument
 // types were left at their std::nullptr_t defaults.
100  auto constexpr has_vector_leaves = !std::is_same_v<vector_output_t, std::nullptr_t>;
101  auto constexpr has_nonlocal_categories = !std::is_same_v<categorical_data_t, std::nullptr_t>;
102 
103  using node_t = typename forest_t::node_type;
104 
105  using output_t = typename forest_t::template raw_output_type<vector_output_t>;
106 
 // Partition trees into groves and rows into chunks (ceiling division, so the
 // last grove/chunk may be short).
107  auto const num_tree = forest.tree_count();
108  auto const num_grove = raft_proto::ceildiv(num_tree, grove_size);
109  auto const num_chunk = raft_proto::ceildiv(row_count, chunk_size);
110 
 // Workspace layout: [row][output][grove], grove varying fastest. Each grove
 // therefore accumulates into its own slots, and the per-grove partial results
 // of one (row, output) pair are contiguous for the reduction below.
 // NOTE(review): the page lists index_type as uint32_t; if ceildiv also yields
 // a 32-bit type, row_count * num_outputs * num_grove can wrap for large
 // inputs — confirm.
111  auto output_workspace = std::vector<output_t>(row_count * num_outputs * num_grove, output_t{});
112  auto const task_count = num_grove * num_chunk;
113 
 // Never request more threads than there are tasks.
 // NOTE(review): if task_count is 0 (no rows or no trees) the clause
 // evaluates to 0, while OpenMP requires a positive value — confirm callers
 // never reach this with an empty input or forest.
114 #pragma omp parallel num_threads(std::min(index_type(omp_get_max_threads()), task_count))
115  {
116  // Infer on each grove and chunk
117 #pragma omp for
118  for (auto task_index = index_type{}; task_index < task_count; ++task_index) {
 // Decode the flat task index into its (grove, chunk) pair and clamp the
 // resulting row/tree ranges to the actual counts.
119  auto const grove_index = task_index / num_chunk;
120  auto const chunk_index = task_index % num_chunk;
121  auto const start_row = chunk_index * chunk_size;
122  auto const end_row = std::min(start_row + chunk_size, row_count);
123  auto const start_tree = grove_index * grove_size;
124  auto const end_tree = std::min(start_tree + grove_size, num_tree);
125 
126  for (auto row_index = start_row; row_index < end_row; ++row_index) {
127  for (auto tree_index = start_tree; tree_index < end_tree; ++tree_index) {
 // Result type depends on the mode: index_type when predicting leaves,
 // otherwise the node's index_type (vector leaves) or threshold_type
 // (scalar leaves).
128  auto tree_output =
129  std::conditional_t<predict_leaf,
130  index_type,
131  std::conditional_t<has_vector_leaves,
132  typename node_t::index_type,
133  typename node_t::threshold_type>>{};
134  tree_output = evaluate_tree<has_vector_leaves,
135  has_categorical_nodes,
136  has_nonlocal_categories,
137  predict_leaf>(
138  forest, tree_index, input + row_index * col_count, categorical_data);
139  if constexpr (predict_leaf) {
 // Leaf prediction: one slot per tree (the tree index is used as the
 // output index); the value is stored, not accumulated.
140  output_workspace[row_index * num_outputs * num_grove + tree_index * num_grove +
141  grove_index] = static_cast<typename forest_t::io_type>(tree_output);
142  } else {
143  auto const default_num_outputs = forest.num_outputs();
144  if constexpr (has_vector_leaves) {
 // Vector leaves: tree_output selects a row of vector_output_p whose
 // default_num_outputs entries are added to consecutive output slots.
 // The boolean comparison is used as a 0/1 multiplier, so the
 // per-tree displacement only applies for infer_kind::per_tree.
145  auto output_offset = (row_index * num_outputs * num_grove +
146  tree_index * default_num_outputs * num_grove *
147  (infer_type == infer_kind::per_tree) +
148  grove_index);
149  for (auto output_index = index_type{}; output_index < default_num_outputs;
150  ++output_index) {
151  output_workspace[output_offset + output_index * num_grove] +=
152  vector_output_p[tree_output * default_num_outputs + output_index];
153  }
154  } else {
 // Scalar leaves: exactly one of the two 0/1-weighted terms is active
 // for the kinds tested here. default_kind maps the tree to output
 // slot (tree_index % default_num_outputs); per_tree gives each tree
 // its own slot. For any other kind both terms vanish and slot 0 of
 // the row is used.
155  auto output_offset =
156  (row_index * num_outputs * num_grove +
157  (tree_index % default_num_outputs) * num_grove *
158  (infer_type == infer_kind::default_kind) +
159  tree_index * num_grove * (infer_type == infer_kind::per_tree) + grove_index);
160  output_workspace[output_offset] += tree_output;
161  }
162  }
163  } // Trees
164  } // Rows
165  } // Tasks
166 
 // The worksharing loop above has no nowait clause, so all threads reach its
 // implicit barrier before the reduction below reads the workspace.
167  // Sum over grove and postprocess
168 #pragma omp for
169  for (auto row_index = index_type{}; row_index < row_count; ++row_index) {
170  for (auto output_index = index_type{}; output_index < num_outputs; ++output_index) {
171  auto grove_offset = (row_index * num_outputs * num_grove + output_index * num_grove);
172 
 // Collapse the num_grove partial results of this (row, output) pair into
 // the first of their slots; the remaining slots keep their old values.
173  output_workspace[grove_offset] =
174  std::accumulate(std::begin(output_workspace) + grove_offset,
175  std::begin(output_workspace) + grove_offset + num_grove,
176  output_t{});
177  }
 // Hand the row's workspace, the output count, the row's destination and
 // num_grove to the postprocessor. The reduced values sit num_grove elements
 // apart, so num_grove is presumably consumed as a stride — confirm against
 // postprocessor::operator().
178  postproc(output_workspace.data() + row_index * num_outputs * num_grove,
179  num_outputs,
180  output + row_index * num_outputs,
181  num_grove);
182  }
183  } // End omp parallel
184 }
185 
186 } // namespace detail
187 } // namespace fil
188 } // namespace experimental
189 } // namespace ML
HOST DEVICE auto evaluate_tree(forest_t const &forest, index_type tree_index, io_t const *__restrict__ row, categorical_data_t categorical_data)
Definition: evaluate_tree.hpp:174
void infer_kernel_cpu(forest_t const &forest, postprocessor< typename forest_t::io_type > const &postproc, typename forest_t::io_type *output, typename forest_t::io_type const *input, index_type row_count, index_type col_count, index_type num_outputs, index_type chunk_size=hardware_constructive_interference_size, index_type grove_size=hardware_constructive_interference_size, vector_output_t vector_output_p=nullptr, categorical_data_t categorical_data=nullptr, infer_kind infer_type=infer_kind::default_kind)
Definition: cpu.hpp:87
uint32_t index_type
Definition: index_type.hpp:21
infer_kind
Definition: infer_kind.hpp:20
forest< real_t > * forest_t
Definition: fil.h:89
Definition: dbscan.hpp:30
HOST DEVICE constexpr auto ceildiv(T dividend, U divisor)
Definition: ceildiv.hpp:21
Definition: forest.hpp:36
HOST DEVICE auto tree_count() const
Definition: forest.hpp:68
HOST DEVICE auto num_outputs() const
Definition: forest.hpp:72
Definition: postprocessor.hpp:141