decision_forest.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 #include <cuml/fil/constants.hpp>
16 #include <cuml/fil/exceptions.hpp>
17 #include <cuml/fil/infer_kind.hpp>
19 #include <cuml/fil/tree_layout.hpp>
20 
21 #include <stddef.h>
22 #include <stdint.h>
23 
24 #include <algorithm>
25 #include <cstddef>
26 #include <limits>
27 #include <optional>
28 #include <variant>
29 
30 namespace ML {
31 namespace fil {
32 
56 template <tree_layout layout_v,
57  typename threshold_t,
58  typename index_t,
59  typename metadata_storage_t,
60  typename offset_t>
65  auto constexpr static const layout = layout_v;
78  using io_type = typename forest_type::io_type;
82  using threshold_type = threshold_t;
91 
96  : nodes_{},
97  root_node_indexes_{},
98  node_id_mapping_{},
99  bias_{},
100  vector_output_{},
101  categorical_storage_{},
102  num_features_{},
103  num_outputs_{},
104  leaf_size_{},
105  has_categorical_nodes_{false},
106  row_postproc_{},
107  elem_postproc_{},
108  average_factor_{},
109  postproc_constant_{}
110  {
111  }
112 
150  raft_proto::buffer<index_type>&& root_node_indexes,
151  raft_proto::buffer<index_type>&& node_id_mapping,
155  bool has_categorical_nodes = false,
156  std::optional<raft_proto::buffer<io_type>>&& vector_output = std::nullopt,
157  std::optional<raft_proto::buffer<typename node_type::index_type>>&&
158  categorical_storage = std::nullopt,
159  index_type leaf_size = index_type{1},
160  row_op row_postproc = row_op::disable,
161  element_op elem_postproc = element_op::disable,
162  io_type average_factor = io_type{1},
163  io_type postproc_constant = io_type{1})
164  : nodes_{nodes},
165  root_node_indexes_{root_node_indexes},
166  node_id_mapping_{node_id_mapping},
167  bias_{bias},
168  vector_output_{vector_output},
169  categorical_storage_{categorical_storage},
170  num_features_{num_features},
171  num_outputs_{num_outputs},
172  leaf_size_{leaf_size},
173  has_categorical_nodes_{has_categorical_nodes},
174  row_postproc_{row_postproc},
175  elem_postproc_{elem_postproc},
176  average_factor_{average_factor},
177  postproc_constant_{postproc_constant}
178  {
179  if (nodes.memory_type() != root_node_indexes.memory_type()) {
181  "Nodes and indexes of forest must both be stored on either host or device");
182  }
183  if (nodes.device_index() != root_node_indexes.device_index()) {
185  "Nodes and indexes of forest must both be stored on same device");
186  }
187  detail::initialize_device<forest_type>(nodes.device());
188  }
189 
191  auto num_features() const { return num_features_; }
193  auto num_trees() const { return root_node_indexes_.size(); }
195  auto has_vector_leaves() const { return vector_output_.has_value(); }
196 
202  auto num_outputs(infer_kind inference_kind = infer_kind::default_kind) const
203  {
204  auto result = num_outputs_;
205  if (inference_kind == infer_kind::per_tree) {
206  result = num_trees();
207  if (has_vector_leaves()) { result *= num_outputs_; }
208  } else if (inference_kind == infer_kind::leaf_id) {
209  result = num_trees();
210  }
211  return result;
212  }
213 
215  auto row_postprocessing() const { return row_postproc_; }
216  // Setter for row_postprocessing
217  void set_row_postprocessing(row_op val) { row_postproc_ = val; }
220  auto elem_postprocessing() const { return elem_postproc_; }
221 
223  auto memory_type() { return nodes_.memory_type(); }
225  auto device_index() { return nodes_.device_index(); }
226 
252  infer_kind predict_type = infer_kind::default_kind,
253  std::optional<index_type> specified_rows_per_block_iter = std::nullopt)
254  {
255  if (output.memory_type() != memory_type() || input.memory_type() != memory_type()) {
257  "Tried to use host I/O data with model on device or vice versa"};
258  }
259  if (output.device_index() != device_index() || input.device_index() != device_index()) {
260  throw raft_proto::wrong_device{"I/O data on different device than model"};
261  }
262  auto* vector_output_data =
263  (vector_output_.has_value() ? vector_output_->data() : static_cast<io_type*>(nullptr));
264  auto* categorical_storage_data =
265  (categorical_storage_.has_value() ? categorical_storage_->data()
266  : static_cast<categorical_storage_type*>(nullptr));
267  switch (nodes_.device().index()) {
268  case 0:
269  fil::detail::infer(obj(),
270  get_postprocessor(predict_type),
271  output.data(),
272  input.data(),
273  index_type(input.size() / num_features_),
274  num_features_,
275  num_outputs(predict_type),
276  has_categorical_nodes_,
277  vector_output_data,
278  categorical_storage_data,
279  predict_type,
280  specified_rows_per_block_iter,
281  std::get<0>(nodes_.device()),
282  stream);
283  break;
284  case 1:
285  fil::detail::infer(obj(),
286  get_postprocessor(predict_type),
287  output.data(),
288  input.data(),
289  index_type(input.size() / num_features_),
290  num_features_,
291  num_outputs(predict_type),
292  has_categorical_nodes_,
293  vector_output_data,
294  categorical_storage_data,
295  predict_type,
296  specified_rows_per_block_iter,
297  std::get<1>(nodes_.device()),
298  stream);
299  break;
300  }
301  }
302 
303  private:
307  raft_proto::buffer<index_type> root_node_indexes_;
309  raft_proto::buffer<index_type> node_id_mapping_;
313  std::optional<raft_proto::buffer<io_type>> vector_output_;
316  std::optional<raft_proto::buffer<categorical_storage_type>> categorical_storage_;
317 
318  // Metadata
319  index_type num_features_;
320  index_type num_outputs_;
321  index_type leaf_size_;
322  bool has_categorical_nodes_ = false;
323  // Postprocessing constants
324  row_op row_postproc_;
325  element_op elem_postproc_;
326  io_type average_factor_;
327  io_type postproc_constant_;
328 
329  auto obj() const
330  {
331  return forest_type{nodes_.data(),
332  root_node_indexes_.data(),
333  node_id_mapping_.data(),
334  bias_.data(),
335  static_cast<index_type>(root_node_indexes_.size()),
336  num_outputs_};
337  }
338 
339  auto get_postprocessor(infer_kind inference_kind = infer_kind::default_kind) const
340  {
341  auto result = postprocessor_type{};
342  if (inference_kind == infer_kind::default_kind) {
343  result =
344  postprocessor_type{row_postproc_, elem_postproc_, average_factor_, postproc_constant_};
345  }
346  return result;
347  }
348 
349  auto leaf_size() const { return leaf_size_; }
350 };
351 
352 namespace detail {
366 template <tree_layout layout, bool double_precision, bool large_trees>
368  layout,
373 
374 } // namespace detail
375 
377 using decision_forest_variant = std::variant<
379  std::variant_alternative_t<0, detail::specialization_variant>::layout,
380  std::variant_alternative_t<0, detail::specialization_variant>::is_double_precision,
381  std::variant_alternative_t<0, detail::specialization_variant>::has_large_trees>,
383  std::variant_alternative_t<1, detail::specialization_variant>::layout,
384  std::variant_alternative_t<1, detail::specialization_variant>::is_double_precision,
385  std::variant_alternative_t<1, detail::specialization_variant>::has_large_trees>,
387  std::variant_alternative_t<2, detail::specialization_variant>::layout,
388  std::variant_alternative_t<2, detail::specialization_variant>::is_double_precision,
389  std::variant_alternative_t<2, detail::specialization_variant>::has_large_trees>,
391  std::variant_alternative_t<3, detail::specialization_variant>::layout,
392  std::variant_alternative_t<3, detail::specialization_variant>::is_double_precision,
393  std::variant_alternative_t<3, detail::specialization_variant>::has_large_trees>,
395  std::variant_alternative_t<4, detail::specialization_variant>::layout,
396  std::variant_alternative_t<4, detail::specialization_variant>::is_double_precision,
397  std::variant_alternative_t<4, detail::specialization_variant>::has_large_trees>,
399  std::variant_alternative_t<5, detail::specialization_variant>::layout,
400  std::variant_alternative_t<5, detail::specialization_variant>::is_double_precision,
401  std::variant_alternative_t<5, detail::specialization_variant>::has_large_trees>,
403  std::variant_alternative_t<6, detail::specialization_variant>::layout,
404  std::variant_alternative_t<6, detail::specialization_variant>::is_double_precision,
405  std::variant_alternative_t<6, detail::specialization_variant>::has_large_trees>,
407  std::variant_alternative_t<7, detail::specialization_variant>::layout,
408  std::variant_alternative_t<7, detail::specialization_variant>::is_double_precision,
409  std::variant_alternative_t<7, detail::specialization_variant>::has_large_trees>,
411  std::variant_alternative_t<8, detail::specialization_variant>::layout,
412  std::variant_alternative_t<8, detail::specialization_variant>::is_double_precision,
413  std::variant_alternative_t<8, detail::specialization_variant>::has_large_trees>,
415  std::variant_alternative_t<9, detail::specialization_variant>::layout,
416  std::variant_alternative_t<9, detail::specialization_variant>::is_double_precision,
417  std::variant_alternative_t<9, detail::specialization_variant>::has_large_trees>,
419  std::variant_alternative_t<10, detail::specialization_variant>::layout,
420  std::variant_alternative_t<10, detail::specialization_variant>::is_double_precision,
421  std::variant_alternative_t<10, detail::specialization_variant>::has_large_trees>,
423  std::variant_alternative_t<11, detail::specialization_variant>::layout,
424  std::variant_alternative_t<11, detail::specialization_variant>::is_double_precision,
425  std::variant_alternative_t<11, detail::specialization_variant>::has_large_trees>>;
426 
445 inline auto get_forest_variant_index(bool use_double_thresholds,
446  index_type max_node_offset,
447  index_type num_features,
448  index_type num_categorical_nodes = index_type{},
449  index_type max_num_categories = index_type{},
450  index_type num_vector_leaves = index_type{},
451  tree_layout layout = preferred_tree_layout)
452 {
453  using small_index_t =
455  auto max_local_categories = index_type(sizeof(small_index_t) * 8);
456  // If the index required for pointing to categorical storage bins or vector
457  // leaf output exceeds what we can store in a uint32_t, uint64_t will be used
458  //
459  // TODO(wphicks): We are overestimating categorical storage required here
460  auto double_indexes_required =
461  (max_num_categories > max_local_categories &&
462  ((raft_proto::ceildiv(max_num_categories, max_local_categories) + 1 * num_categorical_nodes) >
464  num_vector_leaves > std::numeric_limits<small_index_t>::max();
465 
466  auto double_precision = use_double_thresholds || double_indexes_required;
467 
468  using small_metadata_t =
470  using small_offset_t =
472 
473  auto large_trees =
474  (num_features > (std::numeric_limits<small_metadata_t>::max() >> reserved_node_metadata_bits) ||
475  max_node_offset > std::numeric_limits<small_offset_t>::max());
476 
477  auto layout_value = static_cast<std::underlying_type_t<tree_layout>>(layout);
478 
479  return ((index_type{layout_value} << index_type{2}) +
480  (index_type{double_precision} << index_type{1}) + index_type{large_trees});
481 }
482 } // namespace fil
483 } // namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:16
void infer(forest_t const &forest, postprocessor< typename forest_t::io_type > const &postproc, typename forest_t::io_type *output, typename forest_t::io_type *input, index_type row_count, index_type col_count, index_type output_count, bool has_categorical_nodes, typename forest_t::io_type *vector_output=nullptr, typename forest_t::node_type::index_type *categorical_data=nullptr, infer_kind infer_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt, raft_proto::device_id< D > device=raft_proto::device_id< D >{}, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: infer.hpp:57
infer_kind
Definition: infer_kind.hpp:8
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:445
tree_layout
Definition: tree_layout.hpp:8
row_op
Definition: postproc_ops.hpp:10
element_op
Definition: postproc_ops.hpp:17
uint32_t index_type
Definition: index_type.hpp:9
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:425
Definition: dbscan.hpp:18
HOST DEVICE constexpr auto ceildiv(T dividend, U divisor)
Definition: ceildiv.hpp:10
int cuda_stream
Definition: cuda_stream.hpp:14
Definition: decision_forest.hpp:61
auto row_postprocessing() const
Definition: decision_forest.hpp:215
auto elem_postprocessing() const
Definition: decision_forest.hpp:220
typename forest_type::io_type io_type
Definition: decision_forest.hpp:78
decision_forest()
Definition: decision_forest.hpp:95
constexpr static auto const layout
Definition: decision_forest.hpp:65
threshold_t threshold_type
Definition: decision_forest.hpp:82
auto num_trees() const
Definition: decision_forest.hpp:193
void set_row_postprocessing(row_op val)
Definition: decision_forest.hpp:217
postprocessor< io_type > postprocessor_type
Definition: decision_forest.hpp:86
decision_forest(raft_proto::buffer< node_type > &&nodes, raft_proto::buffer< index_type > &&root_node_indexes, raft_proto::buffer< index_type > &&node_id_mapping, raft_proto::buffer< io_type > &&bias, index_type num_features, index_type num_outputs=index_type{2}, bool has_categorical_nodes=false, std::optional< raft_proto::buffer< io_type >> &&vector_output=std::nullopt, std::optional< raft_proto::buffer< typename node_type::index_type >> &&categorical_storage=std::nullopt, index_type leaf_size=index_type{1}, row_op row_postproc=row_op::disable, element_op elem_postproc=element_op::disable, io_type average_factor=io_type{1}, io_type postproc_constant=io_type{1})
Definition: decision_forest.hpp:149
auto num_outputs(infer_kind inference_kind=infer_kind::default_kind) const
Definition: decision_forest.hpp:202
auto has_vector_leaves() const
Definition: decision_forest.hpp:195
typename node_type::index_type categorical_storage_type
Definition: decision_forest.hpp:90
auto device_index()
Definition: decision_forest.hpp:225
auto num_features() const
Definition: decision_forest.hpp:191
auto memory_type()
Definition: decision_forest.hpp:223
void predict(raft_proto::buffer< typename forest_type::io_type > &output, raft_proto::buffer< typename forest_type::io_type > const &input, raft_proto::cuda_stream stream=raft_proto::cuda_stream{}, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_rows_per_block_iter=std::nullopt)
Definition: decision_forest.hpp:249
typename forest_type::node_type node_type
Definition: decision_forest.hpp:74
forest< layout, threshold_t, index_t, metadata_storage_t, offset_t > forest_type
Definition: decision_forest.hpp:70
std::conditional_t< double_precision, double, float > threshold_type
Definition: specialization_types.hpp:36
std::conditional_t< large_trees, std::uint32_t, std::uint16_t > metadata_type
Definition: specialization_types.hpp:42
std::conditional_t< double_precision, std::uint64_t, std::uint32_t > index_type
Definition: specialization_types.hpp:40
std::conditional_t< large_trees, std::uint32_t, std::uint16_t > offset_type
Definition: specialization_types.hpp:44
Definition: forest.hpp:24
threshold_t io_type
Definition: forest.hpp:26
node< layout_v, threshold_t, index_t, metadata_storage_t, offset_t > node_type
Definition: forest.hpp:25
Definition: postprocessor.hpp:135
auto size() const noexcept
Definition: buffer.hpp:282
HOST DEVICE auto * data() const noexcept
Definition: buffer.hpp:283
auto memory_type() const noexcept
Definition: buffer.hpp:284
auto device_index() const noexcept
Definition: buffer.hpp:297
auto device() const noexcept
Definition: buffer.hpp:295
Definition: exceptions.hpp:38
Definition: exceptions.hpp:27
Definition: exceptions.hpp:47