decision_forest.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 #include <cuml/fil/constants.hpp>
27 #include <cuml/fil/exceptions.hpp>
28 #include <cuml/fil/infer_kind.hpp>
30 #include <cuml/fil/tree_layout.hpp>
31 
32 #include <stddef.h>
33 #include <stdint.h>
34 
35 #include <algorithm>
36 #include <cstddef>
37 #include <limits>
38 #include <optional>
39 #include <variant>
40 
41 namespace ML {
42 namespace fil {
43 
67 template <tree_layout layout_v,
68  typename threshold_t,
69  typename index_t,
70  typename metadata_storage_t,
71  typename offset_t>
76  auto constexpr static const layout = layout_v;
89  using io_type = typename forest_type::io_type;
93  using threshold_type = threshold_t;
102 
107  : nodes_{},
108  root_node_indexes_{},
109  node_id_mapping_{},
110  vector_output_{},
111  categorical_storage_{},
112  num_features_{},
113  num_outputs_{},
114  leaf_size_{},
115  has_categorical_nodes_{false},
116  row_postproc_{},
117  elem_postproc_{},
118  average_factor_{},
119  bias_{},
120  postproc_constant_{}
121  {
122  }
123 
161  raft_proto::buffer<index_type>&& root_node_indexes,
162  raft_proto::buffer<index_type>&& node_id_mapping,
165  bool has_categorical_nodes = false,
166  std::optional<raft_proto::buffer<io_type>>&& vector_output = std::nullopt,
167  std::optional<raft_proto::buffer<typename node_type::index_type>>&&
168  categorical_storage = std::nullopt,
169  index_type leaf_size = index_type{1},
170  row_op row_postproc = row_op::disable,
171  element_op elem_postproc = element_op::disable,
172  io_type average_factor = io_type{1},
173  io_type bias = io_type{0},
174  io_type postproc_constant = io_type{1})
175  : nodes_{nodes},
176  root_node_indexes_{root_node_indexes},
177  node_id_mapping_{node_id_mapping},
178  vector_output_{vector_output},
179  categorical_storage_{categorical_storage},
180  num_features_{num_features},
181  num_outputs_{num_outputs},
182  leaf_size_{leaf_size},
183  has_categorical_nodes_{has_categorical_nodes},
184  row_postproc_{row_postproc},
185  elem_postproc_{elem_postproc},
186  average_factor_{average_factor},
187  bias_{bias},
188  postproc_constant_{postproc_constant}
189  {
190  if (nodes.memory_type() != root_node_indexes.memory_type()) {
192  "Nodes and indexes of forest must both be stored on either host or device");
193  }
194  if (nodes.device_index() != root_node_indexes.device_index()) {
196  "Nodes and indexes of forest must both be stored on same device");
197  }
198  detail::initialize_device<forest_type>(nodes.device());
199  }
200 
202  auto num_features() const { return num_features_; }
204  auto num_trees() const { return root_node_indexes_.size(); }
206  auto has_vector_leaves() const { return vector_output_.has_value(); }
207 
210  auto num_outputs(infer_kind inference_kind = infer_kind::default_kind) const
211  {
212  auto result = num_outputs_;
213  if (inference_kind == infer_kind::per_tree) {
214  result = num_trees();
215  if (has_vector_leaves()) { result *= num_outputs_; }
216  } else if (inference_kind == infer_kind::leaf_id) {
217  result = num_trees();
218  }
219  return result;
220  }
221 
223  auto row_postprocessing() const { return row_postproc_; }
224  // Setter for row_postprocessing
225  void set_row_postprocessing(row_op val) { row_postproc_ = val; }
228  auto elem_postprocessing() const { return elem_postproc_; }
229 
231  auto memory_type() { return nodes_.memory_type(); }
233  auto device_index() { return nodes_.device_index(); }
234 
260  infer_kind predict_type = infer_kind::default_kind,
261  std::optional<index_type> specified_rows_per_block_iter = std::nullopt)
262  {
263  if (output.memory_type() != memory_type() || input.memory_type() != memory_type()) {
265  "Tried to use host I/O data with model on device or vice versa"};
266  }
267  if (output.device_index() != device_index() || input.device_index() != device_index()) {
268  throw raft_proto::wrong_device{"I/O data on different device than model"};
269  }
270  auto* vector_output_data =
271  (vector_output_.has_value() ? vector_output_->data() : static_cast<io_type*>(nullptr));
272  auto* categorical_storage_data =
273  (categorical_storage_.has_value() ? categorical_storage_->data()
274  : static_cast<categorical_storage_type*>(nullptr));
275  switch (nodes_.device().index()) {
276  case 0:
277  fil::detail::infer(obj(),
278  get_postprocessor(predict_type),
279  output.data(),
280  input.data(),
281  index_type(input.size() / num_features_),
282  num_features_,
283  num_outputs(predict_type),
284  has_categorical_nodes_,
285  vector_output_data,
286  categorical_storage_data,
287  predict_type,
288  specified_rows_per_block_iter,
289  std::get<0>(nodes_.device()),
290  stream);
291  break;
292  case 1:
293  fil::detail::infer(obj(),
294  get_postprocessor(predict_type),
295  output.data(),
296  input.data(),
297  index_type(input.size() / num_features_),
298  num_features_,
299  num_outputs(predict_type),
300  has_categorical_nodes_,
301  vector_output_data,
302  categorical_storage_data,
303  predict_type,
304  specified_rows_per_block_iter,
305  std::get<1>(nodes_.device()),
306  stream);
307  break;
308  }
309  }
310 
311  private:
315  raft_proto::buffer<index_type> root_node_indexes_;
317  raft_proto::buffer<index_type> node_id_mapping_;
319  std::optional<raft_proto::buffer<io_type>> vector_output_;
322  std::optional<raft_proto::buffer<categorical_storage_type>> categorical_storage_;
323 
324  // Metadata
325  index_type num_features_;
326  index_type num_outputs_;
327  index_type leaf_size_;
328  bool has_categorical_nodes_ = false;
329  // Postprocessing constants
330  row_op row_postproc_;
331  element_op elem_postproc_;
332  io_type average_factor_;
333  io_type bias_;
334  io_type postproc_constant_;
335 
336  auto obj() const
337  {
338  return forest_type{nodes_.data(),
339  root_node_indexes_.data(),
340  node_id_mapping_.data(),
341  static_cast<index_type>(root_node_indexes_.size()),
342  num_outputs_};
343  }
344 
345  auto get_postprocessor(infer_kind inference_kind = infer_kind::default_kind) const
346  {
347  auto result = postprocessor_type{};
348  if (inference_kind == infer_kind::default_kind) {
349  result = postprocessor_type{
350  row_postproc_, elem_postproc_, average_factor_, bias_, postproc_constant_};
351  }
352  return result;
353  }
354 
355  auto leaf_size() const { return leaf_size_; }
356 };
357 
358 namespace detail {
372 template <tree_layout layout, bool double_precision, bool large_trees>
374  layout,
379 
380 } // namespace detail
381 
// Variant holding one of twelve preset decision-forest specializations
// (detail::preset_decision_forest). Alternative k (k = 0..11) is parameterized
// by the layout, is_double_precision and has_large_trees members of
// std::variant_alternative_t<k, detail::specialization_variant>. This variant
// therefore shares its alternative indexes with detail::specialization_variant.
// NOTE(review): presumably get_forest_variant_index() yields the alternative to
// use; confirm that specialization_variant's ordering matches that encoding.
383 using decision_forest_variant = std::variant<
385  std::variant_alternative_t<0, detail::specialization_variant>::layout,
386  std::variant_alternative_t<0, detail::specialization_variant>::is_double_precision,
387  std::variant_alternative_t<0, detail::specialization_variant>::has_large_trees>,
389  std::variant_alternative_t<1, detail::specialization_variant>::layout,
390  std::variant_alternative_t<1, detail::specialization_variant>::is_double_precision,
391  std::variant_alternative_t<1, detail::specialization_variant>::has_large_trees>,
393  std::variant_alternative_t<2, detail::specialization_variant>::layout,
394  std::variant_alternative_t<2, detail::specialization_variant>::is_double_precision,
395  std::variant_alternative_t<2, detail::specialization_variant>::has_large_trees>,
397  std::variant_alternative_t<3, detail::specialization_variant>::layout,
398  std::variant_alternative_t<3, detail::specialization_variant>::is_double_precision,
399  std::variant_alternative_t<3, detail::specialization_variant>::has_large_trees>,
401  std::variant_alternative_t<4, detail::specialization_variant>::layout,
402  std::variant_alternative_t<4, detail::specialization_variant>::is_double_precision,
403  std::variant_alternative_t<4, detail::specialization_variant>::has_large_trees>,
405  std::variant_alternative_t<5, detail::specialization_variant>::layout,
406  std::variant_alternative_t<5, detail::specialization_variant>::is_double_precision,
407  std::variant_alternative_t<5, detail::specialization_variant>::has_large_trees>,
409  std::variant_alternative_t<6, detail::specialization_variant>::layout,
410  std::variant_alternative_t<6, detail::specialization_variant>::is_double_precision,
411  std::variant_alternative_t<6, detail::specialization_variant>::has_large_trees>,
413  std::variant_alternative_t<7, detail::specialization_variant>::layout,
414  std::variant_alternative_t<7, detail::specialization_variant>::is_double_precision,
415  std::variant_alternative_t<7, detail::specialization_variant>::has_large_trees>,
417  std::variant_alternative_t<8, detail::specialization_variant>::layout,
418  std::variant_alternative_t<8, detail::specialization_variant>::is_double_precision,
419  std::variant_alternative_t<8, detail::specialization_variant>::has_large_trees>,
421  std::variant_alternative_t<9, detail::specialization_variant>::layout,
422  std::variant_alternative_t<9, detail::specialization_variant>::is_double_precision,
423  std::variant_alternative_t<9, detail::specialization_variant>::has_large_trees>,
425  std::variant_alternative_t<10, detail::specialization_variant>::layout,
426  std::variant_alternative_t<10, detail::specialization_variant>::is_double_precision,
427  std::variant_alternative_t<10, detail::specialization_variant>::has_large_trees>,
429  std::variant_alternative_t<11, detail::specialization_variant>::layout,
430  std::variant_alternative_t<11, detail::specialization_variant>::is_double_precision,
431  std::variant_alternative_t<11, detail::specialization_variant>::has_large_trees>>;
432 
// Compute which decision-forest specialization a model with the given
// properties needs.
//
// The result packs three properties (see the return statement):
//   (layout << 2) + (double_precision << 1) + large_trees
// where
//   - double_precision is true if use_double_thresholds is set, or if the
//     categorical-storage / vector-leaf index estimate is judged too large
//     for small_index_t;
//   - large_trees is true if num_features does not fit in small_metadata_t
//     after reserving reserved_node_metadata_bits, or if max_node_offset does
//     not fit in small_offset_t.
// NOTE(review): presumably this is the alternative index into
// decision_forest_variant; confirm that the ordering matches.
451 inline auto get_forest_variant_index(bool use_double_thresholds,
452  index_type max_node_offset,
453  index_type num_features,
454  index_type num_categorical_nodes = index_type{},
455  index_type max_num_categories = index_type{},
456  index_type num_vector_leaves = index_type{},
457  tree_layout layout = preferred_tree_layout)
458 {
459  using small_index_t =
461  auto max_local_categories = index_type(sizeof(small_index_t) * 8);
462  // If the index required for pointing to categorical storage bins or vector
463  // leaf output exceeds what we can store in a uint32_t, uint64_t will be used
464  //
465  // TODO(wphicks): We are overestimating categorical storage required here
// NOTE(review): two likely problems in the expression below.
//   1. `+ 1 * num_categorical_nodes` parses as
//      `ceildiv(...) + num_categorical_nodes`, because `*` binds tighter than
//      `+`. If the intent was "(bins per node + 1) * nodes", the parentheses
//      are missing, and the estimate is too small rather than the
//      overestimate the TODO above describes.
//   2. index_type is uint32_t and num_vector_leaves is an index_type. If
//      small_index_t is also 32 bits (the comment above refers to uint32_t),
//      a 32-bit unsigned value can never exceed
//      std::numeric_limits<small_index_t>::max(). Both comparisons would then
//      always be false, and the sum may wrap before it is compared.
//      TODO: confirm, and compute the estimate in a 64-bit type.
466  auto double_indexes_required =
467  (max_num_categories > max_local_categories &&
468  ((raft_proto::ceildiv(max_num_categories, max_local_categories) + 1 * num_categorical_nodes) >
470  num_vector_leaves > std::numeric_limits<small_index_t>::max();
471 
472  auto double_precision = use_double_thresholds || double_indexes_required;
473 
474  using small_metadata_t =
476  using small_offset_t =
478 
479  auto large_trees =
480  (num_features > (std::numeric_limits<small_metadata_t>::max() >> reserved_node_metadata_bits) ||
481  max_node_offset > std::numeric_limits<small_offset_t>::max());
482 
483  auto layout_value = static_cast<std::underlying_type_t<tree_layout>>(layout);
484 
485  return ((index_type{layout_value} << index_type{2}) +
486  (index_type{double_precision} << index_type{1}) + index_type{large_trees});
487 }
488 } // namespace fil
489 } // namespace ML
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
void infer(forest_t const &forest, postprocessor< typename forest_t::io_type > const &postproc, typename forest_t::io_type *output, typename forest_t::io_type *input, index_type row_count, index_type col_count, index_type output_count, bool has_categorical_nodes, typename forest_t::io_type *vector_output=nullptr, typename forest_t::node_type::index_type *categorical_data=nullptr, infer_kind infer_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt, raft_proto::device_id< D > device=raft_proto::device_id< D >{}, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: infer.hpp:68
infer_kind
Definition: infer_kind.hpp:19
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:451
tree_layout
Definition: tree_layout.hpp:19
row_op
Definition: postproc_ops.hpp:21
element_op
Definition: postproc_ops.hpp:28
uint32_t index_type
Definition: index_type.hpp:20
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 
7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:431
Definition: dbscan.hpp:29
HOST DEVICE constexpr auto ceildiv(T dividend, U divisor)
Definition: ceildiv.hpp:21
int cuda_stream
Definition: cuda_stream.hpp:25
Definition: decision_forest.hpp:72
decision_forest(raft_proto::buffer< node_type > &&nodes, raft_proto::buffer< index_type > &&root_node_indexes, raft_proto::buffer< index_type > &&node_id_mapping, index_type num_features, index_type num_outputs=index_type{2}, bool has_categorical_nodes=false, std::optional< raft_proto::buffer< io_type >> &&vector_output=std::nullopt, std::optional< raft_proto::buffer< typename node_type::index_type >> &&categorical_storage=std::nullopt, index_type leaf_size=index_type{1}, row_op row_postproc=row_op::disable, element_op elem_postproc=element_op::disable, io_type average_factor=io_type{1}, io_type bias=io_type{0}, io_type postproc_constant=io_type{1})
Definition: decision_forest.hpp:160
auto row_postprocessing() const
Definition: decision_forest.hpp:223
auto elem_postprocessing() const
Definition: decision_forest.hpp:228
typename forest_type::io_type io_type
Definition: decision_forest.hpp:89
decision_forest()
Definition: decision_forest.hpp:106
constexpr static auto const layout
Definition: decision_forest.hpp:76
threshold_t threshold_type
Definition: decision_forest.hpp:93
auto num_trees() const
Definition: decision_forest.hpp:204
void set_row_postprocessing(row_op val)
Definition: decision_forest.hpp:225
postprocessor< io_type > postprocessor_type
Definition: decision_forest.hpp:97
auto num_outputs(infer_kind inference_kind=infer_kind::default_kind) const
Definition: decision_forest.hpp:210
auto has_vector_leaves() const
Definition: decision_forest.hpp:206
typename node_type::index_type categorical_storage_type
Definition: decision_forest.hpp:101
auto device_index()
Definition: decision_forest.hpp:233
auto num_features() const
Definition: decision_forest.hpp:202
auto memory_type()
Definition: decision_forest.hpp:231
void predict(raft_proto::buffer< typename forest_type::io_type > &output, raft_proto::buffer< typename forest_type::io_type > const &input, raft_proto::cuda_stream stream=raft_proto::cuda_stream{}, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_rows_per_block_iter=std::nullopt)
Definition: decision_forest.hpp:257
typename forest_type::node_type node_type
Definition: decision_forest.hpp:85
forest< layout, threshold_t, index_t, metadata_storage_t, offset_t > forest_type
Definition: decision_forest.hpp:81
std::conditional_t< double_precision, double, float > threshold_type
Definition: specialization_types.hpp:47
std::conditional_t< large_trees, std::uint32_t, std::uint16_t > metadata_type
Definition: specialization_types.hpp:53
std::conditional_t< double_precision, std::uint64_t, std::uint32_t > index_type
Definition: specialization_types.hpp:51
std::conditional_t< large_trees, std::uint32_t, std::uint16_t > offset_type
Definition: specialization_types.hpp:55
Definition: forest.hpp:35
threshold_t io_type
Definition: forest.hpp:37
node< layout_v, threshold_t, index_t, metadata_storage_t, offset_t > node_type
Definition: forest.hpp:36
Definition: postprocessor.hpp:140
auto size() const noexcept
Definition: buffer.hpp:293
HOST DEVICE auto * data() const noexcept
Definition: buffer.hpp:294
auto memory_type() const noexcept
Definition: buffer.hpp:295
auto device_index() const noexcept
Definition: buffer.hpp:308
auto device() const noexcept
Definition: buffer.hpp:306
Definition: exceptions.hpp:49
Definition: exceptions.hpp:38
Definition: exceptions.hpp:58