gpu.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
10 
11 #include <raft/util/cudart_utils.hpp>
12 
13 #include <cuda_runtime_api.h>
14 
15 namespace raft_proto {
16 namespace detail {
17 
19 template <>
// Partial listing of the GPU specialization of device_setter. Source lines 18,
// 20 and 21 (which include the struct head and the constructor signature) are
// absent from this extraction; the definitions index gives the constructor as
// `device_setter(raft_proto::device_id<device_type::gpu> device) noexcept(false)`.
//
// Visible behavior: construction records the currently active CUDA device in
// prev_device_ and then makes `device` the active one; destruction makes
// prev_device_ active again.
22  : prev_device_{[]() {
23  auto result = int{};
// Capture the device that is active before this object switches devices.
// cuda_check is declared noexcept(!GPU_ENABLED), so presumably it throws on a
// CUDA error in GPU-enabled builds — confirm in cuda_check.hpp.
24  raft_proto::cuda_check(cudaGetDevice(&result));
25  return result;
26  }()}
27  {
// Make the requested device current for the lifetime of this object.
28  raft_proto::cuda_check(cudaSetDevice(device.value()));
29  }
30 
// Restore the device that was active at construction.
// NOTE(review): the macro name suggests a failed restore is reported without
// throwing out of the destructor — confirm in raft/util/cudart_utils.hpp.
31  ~device_setter() { RAFT_CUDA_TRY_NO_THROW(cudaSetDevice(prev_device_.value())); }
32 
33  private:
// Device that was active when this object was constructed.
// NOTE(review): the visible lines do not show whether copy/move are deleted; if
// copies are permitted, every copy would restore prev_device_ when destroyed —
// confirm the full struct definition.
34  device_id<device_type::gpu> prev_device_;
35 };
36 
37 } // namespace detail
38 } // namespace raft_proto
Definition: buffer.hpp:24
void cuda_check(error_t const &err) noexcept(!GPU_ENABLED)
Definition: cuda_check.hpp:15
device_type
Definition: device_type.hpp:7
Definition: base.hpp:11
device_setter(raft_proto::device_id< device_type::gpu > device) noexcept(false)
Definition: gpu.hpp:21
Definition: base.hpp:14