14 #include <cuda_runtime.h>
17 #include <type_traits>
39 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.num_features(); },
46 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.num_outputs(); },
53 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.num_trees(); },
60 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.has_vector_leaves(); },
67 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.row_postprocessing(); },
75 [&val](
auto&& concrete_forest) { concrete_forest.set_row_postprocessing(val); },
83 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.elem_postprocessing(); },
90 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.memory_type(); },
97 return std::visit([](
auto&& concrete_forest) {
return concrete_forest.device_index(); },
105 [](
auto&& concrete_forest) {
106 return std::is_same_v<
typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
135 template <
typename io_t>
140 std::optional<index_type> specified_chunk_size = std::nullopt)
143 [
this, predict_type, &output, &input, &stream, &specified_chunk_size](
144 auto&& concrete_forest) {
145 if constexpr (std::is_same_v<
146 typename std::remove_reference_t<decltype(concrete_forest)>::io_type,
148 concrete_forest.predict(output, input, stream, predict_type, specified_chunk_size);
150 throw type_error(
"Input type does not match model_type");
183 template <
typename io_t>
188 std::optional<index_type> specified_chunk_size = std::nullopt)
191 [
this, predict_type, &handle, &output, &input, &specified_chunk_size](
192 auto&& concrete_forest) {
193 using model_io_t =
typename std::remove_reference_t<decltype(concrete_forest)>::io_type;
194 if constexpr (std::is_same_v<model_io_t, io_t>) {
196 concrete_forest.predict(
199 auto constexpr
static const MIN_CHUNKS_PER_PARTITION = std::size_t{64};
200 auto constexpr
static const MAX_CHUNK_SIZE = std::size_t{64};
203 auto partition_size =
205 specified_chunk_size.value_or(MAX_CHUNK_SIZE) * MIN_CHUNKS_PER_PARTITION);
207 for (
auto i = std::size_t{}; i < partition_count; ++i) {
209 auto rows_in_this_partition =
210 std::min(partition_size, row_count - i * partition_size);
215 raft_proto::copy<raft_proto::DEBUG_ENABLED>(partition_in,
237 concrete_forest.predict(
238 partition_out, partition_in, stream, predict_type, specified_chunk_size);
240 raft_proto::copy<raft_proto::DEBUG_ENABLED>(output,
244 partition_out.size(),
250 throw type_error(
"Input type does not match model_type");
282 template <
typename io_t>
286 std::size_t num_rows,
290 std::optional<index_type> specified_chunk_size = std::nullopt)
292 int current_device_id;
298 predict(handle, out_buffer, in_buffer, predict_type, specified_chunk_size);
math_t max(math_t a, math_t b)
Definition: learning_rate.h:16
infer_kind
Definition: infer_kind.hpp:8
row_op
Definition: postproc_ops.hpp:10
std::variant< detail::preset_decision_forest< std::variant_alternative_t< 0, detail::specialization_variant >::layout, std::variant_alternative_t< 0, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 0, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 1, detail::specialization_variant >::layout, std::variant_alternative_t< 1, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 1, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 2, detail::specialization_variant >::layout, std::variant_alternative_t< 2, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 2, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 3, detail::specialization_variant >::layout, std::variant_alternative_t< 3, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 3, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 4, detail::specialization_variant >::layout, std::variant_alternative_t< 4, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 4, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 5, detail::specialization_variant >::layout, std::variant_alternative_t< 5, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 5, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 6, detail::specialization_variant >::layout, std::variant_alternative_t< 6, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 6, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 7, detail::specialization_variant >::layout, std::variant_alternative_t< 7, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 7, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 8, detail::specialization_variant >::layout, std::variant_alternative_t< 8, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 8, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 9, detail::specialization_variant >::layout, std::variant_alternative_t< 9, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 9, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 10, detail::specialization_variant >::layout, std::variant_alternative_t< 10, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 10, detail::specialization_variant >::has_large_trees >, detail::preset_decision_forest< std::variant_alternative_t< 11, detail::specialization_variant >::layout, std::variant_alternative_t< 11, detail::specialization_variant >::is_double_precision, std::variant_alternative_t< 11, detail::specialization_variant >::has_large_trees > > decision_forest_variant
Definition: decision_forest.hpp:425
Definition: dbscan.hpp:18
HOST DEVICE constexpr auto ceildiv(T dividend, U divisor)
Definition: ceildiv.hpp:10
int cuda_stream
Definition: cuda_stream.hpp:14
void cuda_check(error_t const &err) noexcept(!GPU_ENABLED)
Definition: cuda_check.hpp:15
device_type
Definition: device_type.hpp:7
Definition: forest_model.hpp:29
auto row_postprocessing()
Definition: forest_model.hpp:65
auto num_features()
Definition: forest_model.hpp:37
void predict(raft_proto::handle_t const &handle, io_t *output, io_t *input, std::size_t num_rows, raft_proto::device_type out_mem_type, raft_proto::device_type in_mem_type, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:283
void predict(raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, raft_proto::cuda_stream stream=raft_proto::cuda_stream{}, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:136
auto num_trees()
Definition: forest_model.hpp:51
auto num_outputs()
Definition: forest_model.hpp:44
forest_model(decision_forest_variant &&forest=decision_forest_variant{})
Definition: forest_model.hpp:31
auto elem_postprocessing()
Definition: forest_model.hpp:81
void predict(raft_proto::handle_t const &handle, raft_proto::buffer< io_t > &output, raft_proto::buffer< io_t > const &input, infer_kind predict_type=infer_kind::default_kind, std::optional< index_type > specified_chunk_size=std::nullopt)
Definition: forest_model.hpp:184
auto memory_type()
Definition: forest_model.hpp:88
auto has_vector_leaves()
Definition: forest_model.hpp:58
void set_row_postprocessing(row_op val)
Definition: forest_model.hpp:72
auto device_index()
Definition: forest_model.hpp:95
auto is_double_precision()
Definition: forest_model.hpp:102
Definition: forest.hpp:24
Definition: exceptions.hpp:40
A container which may or may not own its own data on host or device.
Definition: buffer.hpp:30
auto size() const noexcept
Definition: buffer.hpp:282
HOST DEVICE auto * data() const noexcept
Definition: buffer.hpp:283
auto memory_type() const noexcept
Definition: buffer.hpp:284
Definition: handle.hpp:36
auto get_usable_stream_count() const
Definition: handle.hpp:39
auto get_next_usable_stream() const
Definition: handle.hpp:37