gpu.hpp
/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cstddef>
#include <optional>

namespace ML {
namespace experimental {
namespace fil {
namespace detail {
namespace inference {
/* The CUDA-free header declaration of the GPU infer template */
template <raft_proto::device_type D,
          bool has_categorical_nodes,
          typename forest_t,
          typename vector_output_t    = std::nullptr_t,
          typename categorical_data_t = std::nullptr_t>
std::enable_if_t<D == raft_proto::device_type::gpu, void> infer(
  forest_t const& forest,
  postprocessor<typename forest_t::io_type> const& postproc,
  typename forest_t::io_type* output,
  typename forest_t::io_type* input,
  index_type row_count,
  index_type col_count,
  index_type class_count,
  vector_output_t vector_output                  = nullptr,
  categorical_data_t categorical_data            = nullptr,
  infer_kind infer_type                          = infer_kind::default_kind,
  std::optional<index_type> specified_chunk_size = std::nullopt,
  raft_proto::device_id<D> device                = raft_proto::device_id<D>{},
  raft_proto::cuda_stream stream                 = raft_proto::cuda_stream{});
}  // namespace inference
}  // namespace detail
}  // namespace fil
}  // namespace experimental
}  // namespace ML
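
For orientation, the sketch below (not part of cuML) illustrates the enable_if-based device dispatch that this declaration participates in: the device_type template parameter D selects exactly one infer overload at compile time, which is why the GPU variant can be declared in a CUDA-free header like this one while its definition is compiled separately with CUDA. All names in the sketch (device_type, infer_sketch) are stand-in assumptions, not real FIL symbols.

// Minimal, self-contained sketch of enable_if-based device dispatch.
#include <iostream>
#include <type_traits>

enum class device_type { cpu, gpu };

// CPU overload: participates in overload resolution only when D == cpu.
template <device_type D>
std::enable_if_t<D == device_type::cpu, void> infer_sketch()
{
  std::cout << "CPU inference path\n";
}

// GPU overload: participates only when D == gpu. In the real library, the
// analogous overload is declared in a CUDA-free header (as above) and its
// definition is built in a CUDA translation unit.
template <device_type D>
std::enable_if_t<D == device_type::gpu, void> infer_sketch()
{
  std::cout << "GPU inference path\n";
}

int main()
{
  infer_sketch<device_type::cpu>();  // resolves to the CPU overload
  infer_sketch<device_type::gpu>();  // resolves to the GPU overload
}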