forest_model.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
#include <cuml/fil/infer_kind.hpp>

#include <cuda_runtime.h>

#include <algorithm>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <variant>

// NOTE(review): several include lines (orig lines 6-11) were lost in
// extraction — the code below references decision_forest_variant, type_error,
// row_op, index_type and raft_proto (buffer/handle/cuda_stream/cuda_check/
// ceildiv/copy), so the corresponding project headers must be restored from
// the original file.
20 namespace ML {
21 namespace fil {
22 
29 struct forest_model {
32  : decision_forest_{forest}
33  {
34  }
35 
37  auto num_features()
38  {
39  return std::visit([](auto&& concrete_forest) { return concrete_forest.num_features(); },
40  decision_forest_);
41  }
42 
44  auto num_outputs()
45  {
46  return std::visit([](auto&& concrete_forest) { return concrete_forest.num_outputs(); },
47  decision_forest_);
48  }
49 
51  auto num_trees()
52  {
53  return std::visit([](auto&& concrete_forest) { return concrete_forest.num_trees(); },
54  decision_forest_);
55  }
56 
59  {
60  return std::visit([](auto&& concrete_forest) { return concrete_forest.has_vector_leaves(); },
61  decision_forest_);
62  }
63 
66  {
67  return std::visit([](auto&& concrete_forest) { return concrete_forest.row_postprocessing(); },
68  decision_forest_);
69  }
70 
73  {
74  return std::visit(
75  [&val](auto&& concrete_forest) { concrete_forest.set_row_postprocessing(val); },
76  decision_forest_);
77  }
78 
82  {
83  return std::visit([](auto&& concrete_forest) { return concrete_forest.elem_postprocessing(); },
84  decision_forest_);
85  }
86 
88  auto memory_type()
89  {
90  return std::visit([](auto&& concrete_forest) { return concrete_forest.memory_type(); },
91  decision_forest_);
92  }
93 
95  auto device_index()
96  {
97  return std::visit([](auto&& concrete_forest) { return concrete_forest.device_index(); },
98  decision_forest_);
99  }
100 
103  {
104  return std::visit(
105  [](auto&& concrete_forest) {
106  return std::is_same_v<typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
107  double>;
108  },
109  decision_forest_);
110  }
111 
135  template <typename io_t>
137  raft_proto::buffer<io_t> const& input,
139  infer_kind predict_type = infer_kind::default_kind,
140  std::optional<index_type> specified_chunk_size = std::nullopt)
141  {
142  std::visit(
143  [this, predict_type, &output, &input, &stream, &specified_chunk_size](
144  auto&& concrete_forest) {
145  if constexpr (std::is_same_v<
146  typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
147  io_t>) {
148  concrete_forest.predict(output, input, stream, predict_type, specified_chunk_size);
149  } else {
150  throw type_error("Input type does not match model_type");
151  }
152  },
153  decision_forest_);
154  }
155 
183  template <typename io_t>
184  void predict(raft_proto::handle_t const& handle,
185  raft_proto::buffer<io_t>& output,
186  raft_proto::buffer<io_t> const& input,
187  infer_kind predict_type = infer_kind::default_kind,
188  std::optional<index_type> specified_chunk_size = std::nullopt)
189  {
190  std::visit(
191  [this, predict_type, &handle, &output, &input, &specified_chunk_size](
192  auto&& concrete_forest) {
193  using model_io_t = typename std::remove_reference_t<decltype(concrete_forest)>::io_type;
194  if constexpr (std::is_same_v<model_io_t, io_t>) {
195  if (output.memory_type() == memory_type() && input.memory_type() == memory_type()) {
196  concrete_forest.predict(
197  output, input, handle.get_next_usable_stream(), predict_type, specified_chunk_size);
198  } else {
199  auto constexpr static const MIN_CHUNKS_PER_PARTITION = std::size_t{64};
200  auto constexpr static const MAX_CHUNK_SIZE = std::size_t{64};
201 
202  auto row_count = input.size() / num_features();
203  auto partition_size =
205  specified_chunk_size.value_or(MAX_CHUNK_SIZE) * MIN_CHUNKS_PER_PARTITION);
206  auto partition_count = raft_proto::ceildiv(row_count, partition_size);
207  for (auto i = std::size_t{}; i < partition_count; ++i) {
208  auto stream = handle.get_next_usable_stream();
209  auto rows_in_this_partition =
210  std::min(partition_size, row_count - i * partition_size);
211  auto partition_in = raft_proto::buffer<io_t>{};
212  if (input.memory_type() != memory_type()) {
213  partition_in =
214  raft_proto::buffer<io_t>{rows_in_this_partition * num_features(), memory_type()};
215  raft_proto::copy<raft_proto::DEBUG_ENABLED>(partition_in,
216  input,
217  0,
218  i * partition_size * num_features(),
219  partition_in.size(),
220  stream);
221  } else {
222  partition_in =
223  raft_proto::buffer<io_t>{input.data() + i * partition_size * num_features(),
224  rows_in_this_partition * num_features(),
225  memory_type()};
226  }
227  auto partition_out = raft_proto::buffer<io_t>{};
228  if (output.memory_type() != memory_type()) {
229  partition_out =
230  raft_proto::buffer<io_t>{rows_in_this_partition * num_outputs(), memory_type()};
231  } else {
232  partition_out =
233  raft_proto::buffer<io_t>{output.data() + i * partition_size * num_outputs(),
234  rows_in_this_partition * num_outputs(),
235  memory_type()};
236  }
237  concrete_forest.predict(
238  partition_out, partition_in, stream, predict_type, specified_chunk_size);
239  if (output.memory_type() != memory_type()) {
240  raft_proto::copy<raft_proto::DEBUG_ENABLED>(output,
241  partition_out,
242  i * partition_size * num_outputs(),
243  0,
244  partition_out.size(),
245  stream);
246  }
247  }
248  }
249  } else {
250  throw type_error("Input type does not match model_type");
251  }
252  },
253  decision_forest_);
254  }
255 
282  template <typename io_t>
283  void predict(raft_proto::handle_t const& handle,
284  io_t* output,
285  io_t* input,
286  std::size_t num_rows,
287  raft_proto::device_type out_mem_type,
288  raft_proto::device_type in_mem_type,
289  infer_kind predict_type = infer_kind::default_kind,
290  std::optional<index_type> specified_chunk_size = std::nullopt)
291  {
292  int current_device_id;
293  raft_proto::cuda_check(cudaGetDevice(¤t_device_id));
294  auto out_buffer =
295  raft_proto::buffer{output, num_rows * num_outputs(), out_mem_type, current_device_id};
296  auto in_buffer =
297  raft_proto::buffer{input, num_rows * num_features(), in_mem_type, current_device_id};
298  predict(handle, out_buffer, in_buffer, predict_type, specified_chunk_size);
299  }
300 
301  private:
302  decision_forest_variant decision_forest_;
303 };
304 
305 } // namespace fil
306 } // namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:16
infer_kind
Definition: infer_kind.hpp:8
row_op
Definition: postproc_ops.hpp:10
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 
7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:425
Definition: dbscan.hpp:18
HOST DEVICE constexpr auto ceildiv(T dividend, U divisor)
Definition: ceildiv.hpp:10
int cuda_stream
Definition: cuda_stream.hpp:14
void cuda_check(error_t const &err) noexcept(!GPU_ENABLED)
Definition: cuda_check.hpp:15
device_type
Definition: device_type.hpp:7
Definition: forest_model.hpp:29
auto row_postprocessing()
Definition: forest_model.hpp:65
auto num_features()
Definition: forest_model.hpp:37
void predict(raft_proto::handle_t const &handle, io_t *output, io_t *input, std::size_t num_rows, raft_proto::device_type out_mem_type, raft_proto::device_type in_mem_type, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:283
void predict(raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, raft_proto::cuda_stream stream=raft_proto::cuda_stream{}, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:136
auto num_trees()
Definition: forest_model.hpp:51
auto num_outputs()
Definition: forest_model.hpp:44
forest_model(decision_forest_variant &&forest=decision_forest_variant{})
Definition: forest_model.hpp:31
auto elem_postprocessing()
Definition: forest_model.hpp:81
void predict(raft_proto::handle_t const &handle, raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:184
auto memory_type()
Definition: forest_model.hpp:88
auto has_vector_leaves()
Definition: forest_model.hpp:58
void set_row_postprocessing(row_op val)
Definition: forest_model.hpp:72
auto device_index()
Definition: forest_model.hpp:95
auto is_double_precision()
Definition: forest_model.hpp:102
Definition: forest.hpp:24
Definition: exceptions.hpp:40
A container which may or may not own its own data on host or device.
Definition: buffer.hpp:30
auto size() const noexcept
Definition: buffer.hpp:282
HOST DEVICE auto * data() const noexcept
Definition: buffer.hpp:283
auto memory_type() const noexcept
Definition: buffer.hpp:284
Definition: handle.hpp:36
auto get_usable_stream_count() const
Definition: handle.hpp:39
auto get_next_usable_stream() const
Definition: handle.hpp:37