forest_model.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2024, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
23 
#include <algorithm>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <variant>
27 
28 namespace ML {
29 namespace experimental {
30 namespace fil {
31 
38 struct forest_model {
41  : decision_forest_{forest}
42  {
43  }
44 
46  auto num_features()
47  {
48  return std::visit([](auto&& concrete_forest) { return concrete_forest.num_features(); },
49  decision_forest_);
50  }
51 
53  auto num_outputs()
54  {
55  return std::visit([](auto&& concrete_forest) { return concrete_forest.num_outputs(); },
56  decision_forest_);
57  }
58 
60  auto num_trees()
61  {
62  return std::visit([](auto&& concrete_forest) { return concrete_forest.num_trees(); },
63  decision_forest_);
64  }
65 
68  {
69  return std::visit([](auto&& concrete_forest) { return concrete_forest.has_vector_leaves(); },
70  decision_forest_);
71  }
72 
75  {
76  return std::visit([](auto&& concrete_forest) { return concrete_forest.row_postprocessing(); },
77  decision_forest_);
78  }
79 
82  {
83  return std::visit(
84  [&val](auto&& concrete_forest) { concrete_forest.set_row_postprocessing(val); },
85  decision_forest_);
86  }
87 
91  {
92  return std::visit([](auto&& concrete_forest) { return concrete_forest.elem_postprocessing(); },
93  decision_forest_);
94  }
95 
97  auto memory_type()
98  {
99  return std::visit([](auto&& concrete_forest) { return concrete_forest.memory_type(); },
100  decision_forest_);
101  }
102 
105  {
106  return std::visit([](auto&& concrete_forest) { return concrete_forest.device_index(); },
107  decision_forest_);
108  }
109 
112  {
113  return std::visit(
114  [](auto&& concrete_forest) {
115  return std::is_same_v<typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
116  double>;
117  },
118  decision_forest_);
119  }
120 
144  template <typename io_t>
146  raft_proto::buffer<io_t> const& input,
148  infer_kind predict_type = infer_kind::default_kind,
149  std::optional<index_type> specified_chunk_size = std::nullopt)
150  {
151  std::visit(
152  [this, predict_type, &output, &input, &stream, &specified_chunk_size](
153  auto&& concrete_forest) {
154  if constexpr (std::is_same_v<
155  typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
156  io_t>) {
157  concrete_forest.predict(output, input, stream, predict_type, specified_chunk_size);
158  } else {
159  throw type_error("Input type does not match model_type");
160  }
161  },
162  decision_forest_);
163  }
164 
192  template <typename io_t>
193  void predict(raft_proto::handle_t const& handle,
194  raft_proto::buffer<io_t>& output,
195  raft_proto::buffer<io_t> const& input,
196  infer_kind predict_type = infer_kind::default_kind,
197  std::optional<index_type> specified_chunk_size = std::nullopt)
198  {
199  std::visit(
200  [this, predict_type, &handle, &output, &input, &specified_chunk_size](
201  auto&& concrete_forest) {
202  using model_io_t = typename std::remove_reference_t<decltype(concrete_forest)>::io_type;
203  if constexpr (std::is_same_v<model_io_t, io_t>) {
204  if (output.memory_type() == memory_type() && input.memory_type() == memory_type()) {
205  concrete_forest.predict(
206  output, input, handle.get_next_usable_stream(), predict_type, specified_chunk_size);
207  } else {
208  auto constexpr static const MIN_CHUNKS_PER_PARTITION = std::size_t{64};
209  auto constexpr static const MAX_CHUNK_SIZE = std::size_t{64};
210 
211  auto row_count = input.size() / num_features();
212  auto partition_size =
214  specified_chunk_size.value_or(MAX_CHUNK_SIZE) * MIN_CHUNKS_PER_PARTITION);
215  auto partition_count = raft_proto::ceildiv(row_count, partition_size);
216  for (auto i = std::size_t{}; i < partition_count; ++i) {
217  auto stream = handle.get_next_usable_stream();
218  auto rows_in_this_partition =
219  std::min(partition_size, row_count - i * partition_size);
220  auto partition_in = raft_proto::buffer<io_t>{};
221  if (input.memory_type() != memory_type()) {
222  partition_in =
223  raft_proto::buffer<io_t>{rows_in_this_partition * num_features(), memory_type()};
224  raft_proto::copy<raft_proto::DEBUG_ENABLED>(partition_in,
225  input,
226  0,
227  i * partition_size * num_features(),
228  partition_in.size(),
229  stream);
230  } else {
231  partition_in =
232  raft_proto::buffer<io_t>{input.data() + i * partition_size * num_features(),
233  rows_in_this_partition * num_features(),
234  memory_type()};
235  }
236  auto partition_out = raft_proto::buffer<io_t>{};
237  if (output.memory_type() != memory_type()) {
238  partition_out =
239  raft_proto::buffer<io_t>{rows_in_this_partition * num_outputs(), memory_type()};
240  } else {
241  partition_out =
242  raft_proto::buffer<io_t>{output.data() + i * partition_size * num_outputs(),
243  rows_in_this_partition * num_outputs(),
244  memory_type()};
245  }
246  concrete_forest.predict(
247  partition_out, partition_in, stream, predict_type, specified_chunk_size);
248  if (output.memory_type() != memory_type()) {
249  raft_proto::copy<raft_proto::DEBUG_ENABLED>(output,
250  partition_out,
251  i * partition_size * num_outputs(),
252  0,
253  partition_out.size(),
254  stream);
255  }
256  }
257  }
258  } else {
259  throw type_error("Input type does not match model_type");
260  }
261  },
262  decision_forest_);
263  }
264 
291  template <typename io_t>
292  void predict(raft_proto::handle_t const& handle,
293  io_t* output,
294  io_t* input,
295  std::size_t num_rows,
296  raft_proto::device_type out_mem_type,
297  raft_proto::device_type in_mem_type,
298  infer_kind predict_type = infer_kind::default_kind,
299  std::optional<index_type> specified_chunk_size = std::nullopt)
300  {
301  // TODO(wphicks): Make sure buffer lands on same device as model
302  auto out_buffer = raft_proto::buffer{output, num_rows * num_outputs(), out_mem_type};
303  auto in_buffer = raft_proto::buffer{input, num_rows * num_features(), in_mem_type};
304  predict(handle, out_buffer, in_buffer, predict_type, specified_chunk_size);
305  }
306 
307  private:
308  decision_forest_variant decision_forest_;
309 };
310 
311 } // namespace fil
312 } // namespace experimental
313 } // namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
infer_kind
Definition: infer_kind.hpp:20
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 
7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:416
row_op
Definition: postproc_ops.hpp:22
Definition: dbscan.hpp:30
HOST DEVICE constexpr auto ceildiv(T dividend, U divisor)
Definition: ceildiv.hpp:21
int cuda_stream
Definition: cuda_stream.hpp:25
device_type
Definition: device_type.hpp:18
Definition: forest_model.hpp:38
auto elem_postprocessing()
Definition: forest_model.hpp:90
auto num_features()
Definition: forest_model.hpp:46
forest_model(decision_forest_variant &&forest=decision_forest_variant{})
Definition: forest_model.hpp:40
auto is_double_precision()
Definition: forest_model.hpp:111
auto device_index()
Definition: forest_model.hpp:104
void predict(raft_proto::handle_t const &handle, raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:193
auto num_trees()
Definition: forest_model.hpp:60
auto has_vector_leaves()
Definition: forest_model.hpp:67
auto num_outputs()
Definition: forest_model.hpp:53
auto memory_type()
Definition: forest_model.hpp:97
auto row_postprocessing()
Definition: forest_model.hpp:74
void set_row_postprocessing(row_op val)
Definition: forest_model.hpp:81
void predict(raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, raft_proto::cuda_stream stream=raft_proto::cuda_stream{}, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:145
void predict(raft_proto::handle_t const &handle, io_t *output, io_t *input, std::size_t num_rows, raft_proto::device_type out_mem_type, raft_proto::device_type in_mem_type, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:292
Definition: forest.hpp:36
Definition: exceptions.hpp:52
A container which may or may not own its own data on host or device.
Definition: buffer.hpp:41
auto size() const noexcept
Definition: buffer.hpp:293
HOST DEVICE auto * data() const noexcept
Definition: buffer.hpp:294
auto memory_type() const noexcept
Definition: buffer.hpp:295
Definition: handle.hpp:47
auto get_usable_stream_count() const
Definition: handle.hpp:50
auto get_next_usable_stream() const
Definition: handle.hpp:48