25 #include <type_traits>
29 namespace experimental {
48 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.num_features(); },
55 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.num_outputs(); },
62 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.num_trees(); },
69 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.has_vector_leaves(); },
76 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.row_postprocessing(); },
84 [&val](
auto&& concrete_forest) { concrete_forest.set_row_postprocessing(val); },
92 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.elem_postprocessing(); },
99 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.memory_type(); },
106 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.device_index(); },
114 [](
auto&& concrete_forest) {
115 return std::is_same_v<
typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
144 template <
typename io_t>
149 std::optional<index_type> specified_chunk_size = std::nullopt)
152 [
this, predict_type, &output, &input, &stream, &specified_chunk_size](
153 auto&& concrete_forest) {
154 if constexpr (std::is_same_v<
155 typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
157 concrete_forest.predict(output, input, stream, predict_type, specified_chunk_size);
159 throw type_error(
"Input type does not match model_type");
192 template <
typename io_t>
197 std::optional<index_type> specified_chunk_size = std::nullopt)
200 [
this, predict_type, &handle, &output, &input, &specified_chunk_size](
201 auto&& concrete_forest) {
202 using model_io_t =
typename std::remove_reference_t<decltype(concrete_forest)>::io_type;
203 if constexpr (std::is_same_v<model_io_t, io_t>) {
205 concrete_forest.predict(
208 auto constexpr
static const MIN_CHUNKS_PER_PARTITION = std::size_t{64};
209 auto constexpr
static const MAX_CHUNK_SIZE = std::size_t{64};
212 auto partition_size =
214 specified_chunk_size.value_or(MAX_CHUNK_SIZE) * MIN_CHUNKS_PER_PARTITION);
216 for (
auto i = std::size_t{}; i < partition_count; ++i) {
218 auto rows_in_this_partition =
219 std::min(partition_size, row_count - i * partition_size);
224 raft_proto::copy<raft_proto::DEBUG_ENABLED>(partition_in,
246 concrete_forest.predict(
247 partition_out, partition_in, stream, predict_type, specified_chunk_size);
249 raft_proto::copy<raft_proto::DEBUG_ENABLED>(output,
253 partition_out.size(),
259 throw type_error(
"Input type does not match model_type");
291 template <
typename io_t>
295 std::size_t num_rows,
299 std::optional<index_type> specified_chunk_size = std::nullopt)
304 predict(handle, out_buffer, in_buffer, predict_type, specified_chunk_size);
math_t max(math_t a, math_t b)
Definition: learning_rate.h:27
infer_kind
Definition: infer_kind.hpp:20
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:416
row_op
Definition: postproc_ops.hpp:22
Definition: dbscan.hpp:30
HOST DEVICE constexpr auto ceildiv(T dividend, U divisor)
Definition: ceildiv.hpp:21
int cuda_stream
Definition: cuda_stream.hpp:25
device_type
Definition: device_type.hpp:18
Definition: forest_model.hpp:38
auto elem_postprocessing()
Definition: forest_model.hpp:90
auto num_features()
Definition: forest_model.hpp:46
forest_model(decision_forest_variant &&forest=decision_forest_variant{})
Definition: forest_model.hpp:40
auto is_double_precision()
Definition: forest_model.hpp:111
auto device_index()
Definition: forest_model.hpp:104
void predict(raft_proto::handle_t const &handle, raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:193
auto num_trees()
Definition: forest_model.hpp:60
auto has_vector_leaves()
Definition: forest_model.hpp:67
auto num_outputs()
Definition: forest_model.hpp:53
auto memory_type()
Definition: forest_model.hpp:97
auto row_postprocessing()
Definition: forest_model.hpp:74
void set_row_postprocessing(row_op val)
Definition: forest_model.hpp:81
void predict(raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, raft_proto::cuda_stream stream=raft_proto::cuda_stream{}, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:145
void predict(raft_proto::handle_t const &handle, io_t *output, io_t *input, std::size_t num_rows, raft_proto::device_type out_mem_type, raft_proto::device_type in_mem_type, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:292
Definition: forest.hpp:36
Definition: exceptions.hpp:52
A container which may or may not own its own data on host or device.
Definition: buffer.hpp:41
auto size() const noexcept
Definition: buffer.hpp:293
HOST DEVICE auto * data() const noexcept
Definition: buffer.hpp:294
auto memory_type() const noexcept
Definition: buffer.hpp:295
Definition: handle.hpp:47
auto get_usable_stream_count() const
Definition: handle.hpp:50
auto get_next_usable_stream() const
Definition: handle.hpp:48