25 #include <type_traits> 
   47     return std::visit([](
auto&& concrete_forest) { 
return concrete_forest.num_features(); },
 
   54     return std::visit([](
auto&& concrete_forest) { 
return concrete_forest.num_outputs(); },
 
   61     return std::visit([](
auto&& concrete_forest) { 
return concrete_forest.num_trees(); },
 
   68     return std::visit([](
auto&& concrete_forest) { 
return concrete_forest.has_vector_leaves(); },
 
   75     return std::visit([](
auto&& concrete_forest) { 
return concrete_forest.row_postprocessing(); },
 
   83       [&val](
auto&& concrete_forest) { concrete_forest.set_row_postprocessing(val); },
 
   91     return std::visit([](
auto&& concrete_forest) { 
return concrete_forest.elem_postprocessing(); },
 
   98     return std::visit([](
auto&& concrete_forest) { 
return concrete_forest.memory_type(); },
 
  105     return std::visit([](
auto&& concrete_forest) { 
return concrete_forest.device_index(); },
 
  113       [](
auto&& concrete_forest) {
 
  114         return std::is_same_v<
typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
 
  143   template <
typename io_t>
 
  148                std::optional<index_type> specified_chunk_size = std::nullopt)
 
  151       [
this, predict_type, &output, &input, &stream, &specified_chunk_size](
 
  152         auto&& concrete_forest) {
 
  153         if constexpr (std::is_same_v<
 
  154                         typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
 
  156           concrete_forest.predict(output, input, stream, predict_type, specified_chunk_size);
 
  158           throw type_error(
"Input type does not match model_type");
 
  191   template <
typename io_t>
 
  196                std::optional<index_type> specified_chunk_size = std::nullopt)
 
  199       [
this, predict_type, &handle, &output, &input, &specified_chunk_size](
 
  200         auto&& concrete_forest) {
 
  201         using model_io_t = 
typename std::remove_reference_t<decltype(concrete_forest)>::io_type;
 
  202         if constexpr (std::is_same_v<model_io_t, io_t>) {
 
  204             concrete_forest.predict(
 
  207             auto constexpr 
static const MIN_CHUNKS_PER_PARTITION = std::size_t{64};
 
  208             auto constexpr 
static const MAX_CHUNK_SIZE           = std::size_t{64};
 
  211             auto partition_size =
 
  213                        specified_chunk_size.value_or(MAX_CHUNK_SIZE) * MIN_CHUNKS_PER_PARTITION);
 
  215             for (
auto i = std::size_t{}; i < partition_count; ++i) {
 
  217               auto rows_in_this_partition =
 
  218                 std::min(partition_size, row_count - i * partition_size);
 
  223                 raft_proto::copy<raft_proto::DEBUG_ENABLED>(partition_in,
 
  245               concrete_forest.predict(
 
  246                 partition_out, partition_in, stream, predict_type, specified_chunk_size);
 
  248                 raft_proto::copy<raft_proto::DEBUG_ENABLED>(output,
 
  252                                                             partition_out.size(),
 
  258           throw type_error(
"Input type does not match model_type");
 
  290   template <
typename io_t>
 
  294                std::size_t num_rows,
 
  298                std::optional<index_type> specified_chunk_size = std::nullopt)
 
  300     int current_device_id;
 
  306     predict(handle, out_buffer, in_buffer, predict_type, specified_chunk_size);
 
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
 
infer_kind
Definition: infer_kind.hpp:19
 
row_op
Definition: postproc_ops.hpp:21
 
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:431
 
Definition: dbscan.hpp:29
 
HOST DEVICE constexpr auto ceildiv(T dividend, U divisor)
Definition: ceildiv.hpp:21
 
int cuda_stream
Definition: cuda_stream.hpp:25
 
void cuda_check(error_t const &err) noexcept(!GPU_ENABLED)
Definition: cuda_check.hpp:26
 
device_type
Definition: device_type.hpp:18
 
Definition: forest_model.hpp:37
 
auto row_postprocessing()
Definition: forest_model.hpp:73
 
auto num_features()
Definition: forest_model.hpp:45
 
void predict(raft_proto::handle_t const &handle, io_t *output, io_t *input, std::size_t num_rows, raft_proto::device_type out_mem_type, raft_proto::device_type in_mem_type, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:291
 
void predict(raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, raft_proto::cuda_stream stream=raft_proto::cuda_stream{}, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:144
 
auto num_trees()
Definition: forest_model.hpp:59
 
auto num_outputs()
Definition: forest_model.hpp:52
 
forest_model(decision_forest_variant &&forest=decision_forest_variant{})
Definition: forest_model.hpp:39
 
auto elem_postprocessing()
Definition: forest_model.hpp:89
 
void predict(raft_proto::handle_t const &handle, raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:192
 
auto memory_type()
Definition: forest_model.hpp:96
 
auto has_vector_leaves()
Definition: forest_model.hpp:66
 
void set_row_postprocessing(row_op val)
Definition: forest_model.hpp:80
 
auto device_index()
Definition: forest_model.hpp:103
 
auto is_double_precision()
Definition: forest_model.hpp:110
 
Definition: forest.hpp:35
 
Definition: exceptions.hpp:51
 
A container which may or may not own its own data on host or device.
Definition: buffer.hpp:41
 
auto size() const noexcept
Definition: buffer.hpp:293
 
HOST DEVICE auto * data() const noexcept
Definition: buffer.hpp:294
 
auto memory_type() const noexcept
Definition: buffer.hpp:295
 
Definition: handle.hpp:47
 
auto get_usable_stream_count() const
Definition: handle.hpp:50
 
auto get_next_usable_stream() const
Definition: handle.hpp:48