70           typename metadata_storage_t,
 
   76   auto constexpr 
static const layout = layout_v;
 
  108       root_node_indexes_{},
 
  111       categorical_storage_{},
 
  115       has_categorical_nodes_{false},
 
  165                   bool has_categorical_nodes                                 = 
false,
 
  166                   std::optional<raft_proto::buffer<io_type>>&& vector_output = std::nullopt,
 
  167                   std::optional<raft_proto::buffer<typename node_type::index_type>>&&
 
  168                     categorical_storage     = std::nullopt,
 
  176       root_node_indexes_{root_node_indexes},
 
  177       node_id_mapping_{node_id_mapping},
 
  178       vector_output_{vector_output},
 
  179       categorical_storage_{categorical_storage},
 
  182       leaf_size_{leaf_size},
 
  183       has_categorical_nodes_{has_categorical_nodes},
 
  184       row_postproc_{row_postproc},
 
  185       elem_postproc_{elem_postproc},
 
  186       average_factor_{average_factor},
 
  188       postproc_constant_{postproc_constant}
 
  190     if (nodes.memory_type() != root_node_indexes.memory_type()) {
 
  192         "Nodes and indexes of forest must both be stored on either host or device");
 
  194     if (nodes.device_index() != root_node_indexes.device_index()) {
 
  196         "Nodes and indexes of forest must both be stored on same device");
 
  198     detail::initialize_device<forest_type>(nodes.device());
 
  212     auto result = num_outputs_;
 
  261                std::optional<index_type> specified_rows_per_block_iter = std::nullopt)
 
  265         "Tried to use host I/O data with model on device or vice versa"};
 
  270     auto* vector_output_data =
 
  271       (vector_output_.has_value() ? vector_output_->data() : 
static_cast<io_type*
>(
nullptr));
 
  272     auto* categorical_storage_data =
 
  273       (categorical_storage_.has_value() ? categorical_storage_->data()
 
  275     switch (nodes_.
device().index()) {
 
  278                            get_postprocessor(predict_type),
 
  284                            has_categorical_nodes_,
 
  286                            categorical_storage_data,
 
  288                            specified_rows_per_block_iter,
 
  289                            std::get<0>(nodes_.
device()),
 
  294                            get_postprocessor(predict_type),
 
  300                            has_categorical_nodes_,
 
  302                            categorical_storage_data,
 
  304                            specified_rows_per_block_iter,
 
  305                            std::get<1>(nodes_.
device()),
 
  319   std::optional<raft_proto::buffer<io_type>> vector_output_;
 
  322   std::optional<raft_proto::buffer<categorical_storage_type>> categorical_storage_;
 
  328   bool has_categorical_nodes_ = 
false;
 
  339                        root_node_indexes_.
data(),
 
  340                        node_id_mapping_.
data(),
 
  350         row_postproc_, elem_postproc_, average_factor_, bias_, postproc_constant_};
 
  355   auto leaf_size()
 const { 
return leaf_size_; }
 
  372 template <tree_layout layout, 
bool double_precision, 
bool large_trees>
 
  385     std::variant_alternative_t<0, detail::specialization_variant>::layout,
 
  386     std::variant_alternative_t<0, detail::specialization_variant>::is_double_precision,
 
  387     std::variant_alternative_t<0, detail::specialization_variant>::has_large_trees>,
 
  389     std::variant_alternative_t<1, detail::specialization_variant>::layout,
 
  390     std::variant_alternative_t<1, detail::specialization_variant>::is_double_precision,
 
  391     std::variant_alternative_t<1, detail::specialization_variant>::has_large_trees>,
 
  393     std::variant_alternative_t<2, detail::specialization_variant>::layout,
 
  394     std::variant_alternative_t<2, detail::specialization_variant>::is_double_precision,
 
  395     std::variant_alternative_t<2, detail::specialization_variant>::has_large_trees>,
 
  397     std::variant_alternative_t<3, detail::specialization_variant>::layout,
 
  398     std::variant_alternative_t<3, detail::specialization_variant>::is_double_precision,
 
  399     std::variant_alternative_t<3, detail::specialization_variant>::has_large_trees>,
 
  401     std::variant_alternative_t<4, detail::specialization_variant>::layout,
 
  402     std::variant_alternative_t<4, detail::specialization_variant>::is_double_precision,
 
  403     std::variant_alternative_t<4, detail::specialization_variant>::has_large_trees>,
 
  405     std::variant_alternative_t<5, detail::specialization_variant>::layout,
 
  406     std::variant_alternative_t<5, detail::specialization_variant>::is_double_precision,
 
  407     std::variant_alternative_t<5, detail::specialization_variant>::has_large_trees>,
 
  409     std::variant_alternative_t<6, detail::specialization_variant>::layout,
 
  410     std::variant_alternative_t<6, detail::specialization_variant>::is_double_precision,
 
  411     std::variant_alternative_t<6, detail::specialization_variant>::has_large_trees>,
 
  413     std::variant_alternative_t<7, detail::specialization_variant>::layout,
 
  414     std::variant_alternative_t<7, detail::specialization_variant>::is_double_precision,
 
  415     std::variant_alternative_t<7, detail::specialization_variant>::has_large_trees>,
 
  417     std::variant_alternative_t<8, detail::specialization_variant>::layout,
 
  418     std::variant_alternative_t<8, detail::specialization_variant>::is_double_precision,
 
  419     std::variant_alternative_t<8, detail::specialization_variant>::has_large_trees>,
 
  421     std::variant_alternative_t<9, detail::specialization_variant>::layout,
 
  422     std::variant_alternative_t<9, detail::specialization_variant>::is_double_precision,
 
  423     std::variant_alternative_t<9, detail::specialization_variant>::has_large_trees>,
 
  425     std::variant_alternative_t<10, detail::specialization_variant>::layout,
 
  426     std::variant_alternative_t<10, detail::specialization_variant>::is_double_precision,
 
  427     std::variant_alternative_t<10, detail::specialization_variant>::has_large_trees>,
 
  429     std::variant_alternative_t<11, detail::specialization_variant>::layout,
 
  430     std::variant_alternative_t<11, detail::specialization_variant>::is_double_precision,
 
  431     std::variant_alternative_t<11, detail::specialization_variant>::has_large_trees>>;
 
  459   using small_index_t =
 
  461   auto max_local_categories = 
index_type(
sizeof(small_index_t) * 8);
 
  466   auto double_indexes_required =
 
  467     (max_num_categories > max_local_categories &&
 
  468      ((
raft_proto::ceildiv(max_num_categories, max_local_categories) + 1 * num_categorical_nodes) >
 
  472   auto double_precision = use_double_thresholds || double_indexes_required;
 
  474   using small_metadata_t =
 
  476   using small_offset_t =
 
  483   auto layout_value = 
static_cast<std::underlying_type_t<tree_layout>
>(layout);
 
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
 
void infer(forest_t const &forest, postprocessor< typename forest_t::io_type > const &postproc, typename forest_t::io_type *output, typename forest_t::io_type *input, index_type row_count, index_type col_count, index_type output_count, bool has_categorical_nodes, typename forest_t::io_type *vector_output=nullptr, typename forest_t::node_type::index_type *categorical_data=nullptr, infer_kind infer_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt, raft_proto::device_id< D > device=raft_proto::device_id< D >{}, raft_proto::cuda_stream stream=raft_proto::cuda_stream{})
Definition: infer.hpp:68
 
infer_kind
Definition: infer_kind.hpp:19
 
auto get_forest_variant_index(bool use_double_thresholds, index_type max_node_offset, index_type num_features, index_type num_categorical_nodes=index_type{}, index_type max_num_categories=index_type{}, index_type num_vector_leaves=index_type{}, tree_layout layout=preferred_tree_layout)
Definition: decision_forest.hpp:451
 
tree_layout
Definition: tree_layout.hpp:19
 
row_op
Definition: postproc_ops.hpp:21
 
element_op
Definition: postproc_ops.hpp:28
 
uint32_t index_type
Definition: index_type.hpp:20
 
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:431
 
Definition: dbscan.hpp:29
 
HOST DEVICE constexpr auto ceildiv(T dividend, U divisor)
Definition: ceildiv.hpp:21
 
int cuda_stream
Definition: cuda_stream.hpp:25
 
Definition: decision_forest.hpp:72
 
decision_forest(raft_proto::buffer< node_type > &&nodes, raft_proto::buffer< index_type > &&root_node_indexes, raft_proto::buffer< index_type > &&node_id_mapping, index_type num_features, index_type num_outputs=index_type{2}, bool has_categorical_nodes=false, std::optional< raft_proto::buffer< io_type >> &&vector_output=std::nullopt, std::optional< raft_proto::buffer< typename node_type::index_type >> &&categorical_storage=std::nullopt, index_type leaf_size=index_type{1}, row_op row_postproc=row_op::disable, element_op elem_postproc=element_op::disable, io_type average_factor=io_type{1}, io_type bias=io_type{0}, io_type postproc_constant=io_type{1})
Definition: decision_forest.hpp:160
 
auto row_postprocessing() const
Definition: decision_forest.hpp:223
 
auto elem_postprocessing() const
Definition: decision_forest.hpp:228
 
typename forest_type::io_type io_type
Definition: decision_forest.hpp:89
 
decision_forest()
Definition: decision_forest.hpp:106
 
constexpr static auto const layout
Definition: decision_forest.hpp:76
 
threshold_t threshold_type
Definition: decision_forest.hpp:93
 
auto num_trees() const
Definition: decision_forest.hpp:204
 
void set_row_postprocessing(row_op val)
Definition: decision_forest.hpp:225
 
postprocessor< io_type > postprocessor_type
Definition: decision_forest.hpp:97
 
auto num_outputs(infer_kind inference_kind=infer_kind::default_kind) const
Definition: decision_forest.hpp:210
 
auto has_vector_leaves() const
Definition: decision_forest.hpp:206
 
typename node_type::index_type categorical_storage_type
Definition: decision_forest.hpp:101
 
auto device_index()
Definition: decision_forest.hpp:233
 
auto num_features() const
Definition: decision_forest.hpp:202
 
auto memory_type()
Definition: decision_forest.hpp:231
 
void predict(raft_proto::buffer< typename forest_type::io_type > &output, raft_proto::buffer< typename forest_type::io_type > const &input, raft_proto::cuda_stream stream=raft_proto::cuda_stream{}, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_rows_per_block_iter=std::nullopt)
Definition: decision_forest.hpp:257
 
typename forest_type::node_type node_type
Definition: decision_forest.hpp:85
 
forest< layout, threshold_t, index_t, metadata_storage_t, offset_t > forest_type
Definition: decision_forest.hpp:81
 
std::conditional_t< double_precision, double, float > threshold_type
Definition: specialization_types.hpp:47
 
std::conditional_t< large_trees, std::uint32_t, std::uint16_t > metadata_type
Definition: specialization_types.hpp:53
 
std::conditional_t< double_precision, std::uint64_t, std::uint32_t > index_type
Definition: specialization_types.hpp:51
 
std::conditional_t< large_trees, std::uint32_t, std::uint16_t > offset_type
Definition: specialization_types.hpp:55
 
Definition: forest.hpp:35
 
threshold_t io_type
Definition: forest.hpp:37
 
node< layout_v, threshold_t, index_t, metadata_storage_t, offset_t > node_type
Definition: forest.hpp:36
 
Definition: postprocessor.hpp:140
 
auto size() const noexcept
Definition: buffer.hpp:293
 
HOST DEVICE auto * data() const noexcept
Definition: buffer.hpp:294
 
auto memory_type() const noexcept
Definition: buffer.hpp:295
 
auto device_index() const noexcept
Definition: buffer.hpp:308
 
auto device() const noexcept
Definition: buffer.hpp:306
 
Definition: exceptions.hpp:49
 
Definition: exceptions.hpp:38
 
Definition: exceptions.hpp:58