forest_model.hpp
/*
 * Copyright (c) 2023-2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once
#include <cuml/fil/infer_kind.hpp>

#include <cuda_runtime.h>

#include <cstddef>
#include <optional>
#include <type_traits>
#include <variant>

namespace ML {
namespace fil {

/* Wrapper around any of the concrete decision forest representations,
 * providing a uniform interface for metadata queries and inference */
struct forest_model {
  /* Construct a forest_model from a variant of concrete decision forest types */
  forest_model(decision_forest_variant&& forest = decision_forest_variant{})
    : decision_forest_{forest}
  {
  }

  /* The number of features expected as input by this model */
  auto num_features()
  {
    return std::visit([](auto&& concrete_forest) { return concrete_forest.num_features(); },
                      decision_forest_);
  }

  /* The number of outputs produced per row */
  auto num_outputs()
  {
    return std::visit([](auto&& concrete_forest) { return concrete_forest.num_outputs(); },
                      decision_forest_);
  }

  /* The number of trees in this model */
  auto num_trees()
  {
    return std::visit([](auto&& concrete_forest) { return concrete_forest.num_trees(); },
                      decision_forest_);
  }

  /* Whether this model stores vector (rather than scalar) leaf outputs */
  auto has_vector_leaves()
  {
    return std::visit([](auto&& concrete_forest) { return concrete_forest.has_vector_leaves(); },
                      decision_forest_);
  }

  /* The postprocessing operation applied to each output row */
  auto row_postprocessing()
  {
    return std::visit([](auto&& concrete_forest) { return concrete_forest.row_postprocessing(); },
                      decision_forest_);
  }

  /* Set the postprocessing operation applied to each output row */
  void set_row_postprocessing(row_op val)
  {
    return std::visit(
      [&val](auto&& concrete_forest) { concrete_forest.set_row_postprocessing(val); },
      decision_forest_);
  }

  /* The postprocessing operation applied to each output element */
  auto elem_postprocessing()
  {
    return std::visit([](auto&& concrete_forest) { return concrete_forest.elem_postprocessing(); },
                      decision_forest_);
  }

  /* The memory type (host or device) in which the model is stored */
  auto memory_type()
  {
    return std::visit([](auto&& concrete_forest) { return concrete_forest.memory_type(); },
                      decision_forest_);
  }

  /* The device index on which the model is stored */
  auto device_index()
  {
    return std::visit([](auto&& concrete_forest) { return concrete_forest.device_index(); },
                      decision_forest_);
  }

  /* Whether this model performs inference in double precision */
  auto is_double_precision()
  {
    return std::visit(
      [](auto&& concrete_forest) {
        return std::is_same_v<typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
                              double>;
      },
      decision_forest_);
  }

  /* Run inference on the given input buffer, writing results to the given
   * output buffer on the indicated CUDA stream. Throws type_error if io_t
   * does not match the precision of the underlying model. */
  template <typename io_t>
  void predict(raft_proto::buffer<io_t>& output,
               raft_proto::buffer<io_t> const& input,
               raft_proto::cuda_stream stream = raft_proto::cuda_stream{},
               infer_kind predict_type = infer_kind::default_kind,
               std::optional<index_type> specified_chunk_size = std::nullopt)
  {
    std::visit(
      [this, predict_type, &output, &input, &stream, &specified_chunk_size](
        auto&& concrete_forest) {
        if constexpr (std::is_same_v<
                        typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
                        io_t>) {
          concrete_forest.predict(output, input, stream, predict_type, specified_chunk_size);
        } else {
          throw type_error("Input type does not match model_type");
        }
      },
      decision_forest_);
  }

  /* Run inference using the streams of the given handle. If input and output
   * are already in the model's memory type, inference is performed directly;
   * otherwise the data are partitioned, copied, and processed across the
   * handle's streams. */
  template <typename io_t>
  void predict(raft_proto::handle_t const& handle,
               raft_proto::buffer<io_t>& output,
               raft_proto::buffer<io_t> const& input,
               infer_kind predict_type = infer_kind::default_kind,
               std::optional<index_type> specified_chunk_size = std::nullopt)
  {
    std::visit(
      [this, predict_type, &handle, &output, &input, &specified_chunk_size](
        auto&& concrete_forest) {
        using model_io_t = typename std::remove_reference_t<decltype(concrete_forest)>::io_type;
        if constexpr (std::is_same_v<model_io_t, io_t>) {
          if (output.memory_type() == memory_type() && input.memory_type() == memory_type()) {
            concrete_forest.predict(
              output, input, handle.get_next_usable_stream(), predict_type, specified_chunk_size);
          } else {
            auto constexpr static const MIN_CHUNKS_PER_PARTITION = std::size_t{64};
            auto constexpr static const MAX_CHUNK_SIZE           = std::size_t{64};

            // Partition rows across the handle's usable streams, but keep each
            // partition large enough to amortize the staging copies below.
            auto row_count = input.size() / num_features();
            auto partition_size =
              max(raft_proto::ceildiv(row_count, handle.get_usable_stream_count()),
                  specified_chunk_size.value_or(MAX_CHUNK_SIZE) * MIN_CHUNKS_PER_PARTITION);
            auto partition_count = raft_proto::ceildiv(row_count, partition_size);
            for (auto i = std::size_t{}; i < partition_count; ++i) {
              auto stream = handle.get_next_usable_stream();
              auto rows_in_this_partition =
                std::min(partition_size, row_count - i * partition_size);
              // Stage this partition's input in the model's memory type,
              // copying only if the source is in a different memory space.
              auto partition_in = raft_proto::buffer<io_t>{};
              if (input.memory_type() != memory_type()) {
                partition_in =
                  raft_proto::buffer<io_t>{rows_in_this_partition * num_features(), memory_type()};
                raft_proto::copy<raft_proto::DEBUG_ENABLED>(partition_in,
                                                            input,
                                                            0,
                                                            i * partition_size * num_features(),
                                                            partition_in.size(),
                                                            stream);
              } else {
                partition_in =
                  raft_proto::buffer<io_t>{input.data() + i * partition_size * num_features(),
                                           rows_in_this_partition * num_features(),
                                           memory_type()};
              }
              // Stage the output similarly: a scratch buffer if the destination
              // is in a different memory space, otherwise a non-owning view
              // into the caller's output buffer.
              auto partition_out = raft_proto::buffer<io_t>{};
              if (output.memory_type() != memory_type()) {
                partition_out =
                  raft_proto::buffer<io_t>{rows_in_this_partition * num_outputs(), memory_type()};
              } else {
                partition_out =
                  raft_proto::buffer<io_t>{output.data() + i * partition_size * num_outputs(),
                                           rows_in_this_partition * num_outputs(),
                                           memory_type()};
              }
              concrete_forest.predict(
                partition_out, partition_in, stream, predict_type, specified_chunk_size);
              // Copy results back to the caller's buffer if a scratch buffer was used.
              if (output.memory_type() != memory_type()) {
                raft_proto::copy<raft_proto::DEBUG_ENABLED>(output,
                                                            partition_out,
                                                            i * partition_size * num_outputs(),
                                                            0,
                                                            partition_out.size(),
                                                            stream);
              }
            }
          }
        } else {
          throw type_error("Input type does not match model_type");
        }
      },
      decision_forest_);
  }

  /* Run inference on raw host or device pointers. The pointers are wrapped in
   * non-owning buffers on the current CUDA device and passed to the
   * buffer-based overload above. */
  template <typename io_t>
  void predict(raft_proto::handle_t const& handle,
               io_t* output,
               io_t* input,
               std::size_t num_rows,
               raft_proto::device_type out_mem_type,
               raft_proto::device_type in_mem_type,
               infer_kind predict_type = infer_kind::default_kind,
               std::optional<index_type> specified_chunk_size = std::nullopt)
  {
    int current_device_id;
    raft_proto::cuda_check(cudaGetDevice(&current_device_id));
    auto out_buffer =
      raft_proto::buffer{output, num_rows * num_outputs(), out_mem_type, current_device_id};
    auto in_buffer =
      raft_proto::buffer{input, num_rows * num_features(), in_mem_type, current_device_id};
    predict(handle, out_buffer, in_buffer, predict_type, specified_chunk_size);
  }

 private:
  decision_forest_variant decision_forest_;
};

}  // namespace fil
}  // namespace ML
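
A minimal usage sketch of the raw-pointer predict overload above. The include path, the run_inference wrapper, the single-precision (float) assumption, and the device_type::cpu arguments are illustrative assumptions rather than part of this header; a model whose io_type is double would throw type_error for float inputs.

#include <cuml/fil/forest_model.hpp>  // path assumed from the cuml/fil include layout above

#include <cstddef>
#include <vector>

// Hypothetical caller: `model` must already have been constructed elsewhere
// (for example via the FIL model-import utilities); this only shows the call.
void run_inference(ML::fil::forest_model& model,
                   raft_proto::handle_t const& handle,
                   std::vector<float>& features,  // num_rows * model.num_features() values, row-major
                   std::size_t num_rows)
{
  auto output = std::vector<float>(num_rows * model.num_outputs());
  // Wraps both pointers in non-owning host buffers and forwards to the
  // handle-based overload, which partitions and copies the data into the
  // model's memory space as needed.
  model.predict(handle,
                output.data(),
                features.data(),
                num_rows,
                raft_proto::device_type::cpu,
                raft_proto::device_type::cpu);
}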