device_initialization.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
8 
9 #include <variant>
10 #ifdef CUML_ENABLE_GPU
12 #endif
13 
14 namespace ML {
15 namespace fil {
16 namespace detail {
17 /* Set any required device options for optimizing FIL compute */
18 template <typename forest_t, raft_proto::device_type D>
20 {
21  device_initialization::initialize_device<forest_t>(device);
22 }
23 
24 /* Set any required device options for optimizing FIL compute */
25 template <typename forest_t>
27 {
28  std::visit(
29  [](auto&& concrete_device) {
30  device_initialization::initialize_device<forest_t>(concrete_device);
31  },
32  device);
33 }
34 } // namespace detail
35 } // namespace fil
36 } // namespace ML
void initialize_device(raft_proto::device_id< D > device)
Definition: device_initialization.hpp:19
Definition: dbscan.hpp:18
std::variant< device_id< device_type::cpu >, device_id< device_type::gpu > > device_id_variant
Definition: device_id.hpp:20
Definition: base.hpp:11