cuda_device.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include <rmm/detail/error.hpp>
19 
20 #include <cuda_runtime_api.h>
21 
22 namespace rmm {
23 
// Integer type used to hold a device identifier (an alias for int).
34  using value_type = int;
35 
// Wraps the integer dev_id in a cuda_device_id. Marked explicit so that a bare
// integer is never converted to a device id implicitly. The value is stored as
// given; no range check is performed here.
41  explicit constexpr cuda_device_id(value_type dev_id) noexcept : id_{dev_id} {}
42 
// Returns the wrapped integer identifier exactly as it was passed to the constructor.
44  [[nodiscard]] constexpr value_type value() const noexcept { return id_; }
45 
46  // TODO re-add doxygen comment specifier /** for these hidden friend operators once this Breathe
47  // bug is fixed: https://github.com/breathe-doc/breathe/issues/916
49 
// Equality comparison for two cuda_device_id objects, defined as a friend inside the class.
// lhs, rhs: the two identifiers to compare.
// Returns true when both wrap the same integer value, false otherwise.
56  [[nodiscard]] constexpr friend bool operator==(cuda_device_id const& lhs,
57  cuda_device_id const& rhs) noexcept
58  {
59  return lhs.value() == rhs.value();
60  }
61 
// Inequality comparison for two cuda_device_id objects, defined as a friend inside the class.
// lhs, rhs: the two identifiers to compare.
// Returns true when the wrapped integer values differ, false otherwise.
69  [[nodiscard]] constexpr friend bool operator!=(cuda_device_id const& lhs,
70  cuda_device_id const& rhs) noexcept
71  {
72  return lhs.value() != rhs.value();
73  }
75  private:
// The wrapped device identifier; set once by the constructor and read back through value().
76  value_type id_;
77 };
78 
// Body of get_current_cuda_device(), which returns a cuda_device_id for the current device.
// NOTE(review): the signature line (listing line 86) is absent from this listing.
87 {
// Initialized to -1; cudaGetDevice writes the current device into dev_id.
88  cuda_device_id::value_type dev_id{-1};
// The CUDA return code is passed to RMM_ASSERT_CUDA_SUCCESS (declared in
// rmm/detail/error.hpp, not visible here) -- presumably it asserts on failure; confirm.
89  RMM_ASSERT_CUDA_SUCCESS(cudaGetDevice(&dev_id));
// Wrap the raw integer in the strong type before returning it.
90  return cuda_device_id{dev_id};
91 }
92 
// Body of get_num_cuda_devices(), which returns the number of CUDA devices as an int.
// NOTE(review): the signature line (listing line 98) is absent from this listing.
99 {
// Initialized to -1; cudaGetDeviceCount writes the device count into num_dev.
100  cuda_device_id::value_type num_dev{-1};
// The CUDA return code is passed to RMM_ASSERT_CUDA_SUCCESS (declared in
// rmm/detail/error.hpp, not visible here) -- presumably it asserts on failure; confirm.
101  RMM_ASSERT_CUDA_SUCCESS(cudaGetDeviceCount(&num_dev));
// The count is returned as a plain integer, not wrapped in cuda_device_id.
102  return num_dev;
103 }
104 
// Member-initializer list and body of the cuda_set_device_raii(cuda_device_id dev_id)
// constructor. NOTE(review): the signature line (listing line 115) is absent from this listing.
// old_device_ records the device that is current at construction time.
116  : old_device_{get_current_cuda_device()},
// A switch is needed only when dev_id is non-negative and differs from the current device,
// so a negative dev_id leaves the current device untouched.
// This initializer reads old_device_; that is well-defined because old_device_ is declared
// before needs_reset_ in the class and is therefore initialized first.
117  needs_reset_{dev_id.value() >= 0 && old_device_ != dev_id}
118  {
// Make dev_id the current device only when the check above says it is required.
119  if (needs_reset_) { RMM_ASSERT_CUDA_SUCCESS(cudaSetDevice(dev_id.value())); }
120  }
// Body of ~cuda_set_device_raii() noexcept.
// NOTE(review): the signature line (listing line 124) is absent from this listing.
125  {
// Switch back to the device recorded at construction, but only if the constructor
// actually changed the current device; otherwise nothing is done.
126  if (needs_reset_) { RMM_ASSERT_CUDA_SUCCESS(cudaSetDevice(old_device_.value())); }
127  }
128 
// Copy assignment is deleted.
// NOTE(review): listing lines 129 and 131 are not shown here; presumably they hold the
// matching deleted copy/move constructors -- confirm against the real header.
130  cuda_set_device_raii& operator=(cuda_set_device_raii const&) = delete;
// Move assignment is deleted.
132  cuda_set_device_raii& operator=(cuda_set_device_raii&&) = delete;
133 
134  private:
// Device that was current when the object was constructed. Declared before needs_reset_
// on purpose: the constructor's initializer for needs_reset_ reads this member.
135  cuda_device_id old_device_;
// True when the constructor switched devices, i.e. the destructor must switch back.
136  bool needs_reset_;
137 };
138  // end of group
140 } // namespace rmm
cuda_device_id get_current_cuda_device()
Returns a cuda_device_id for the current device.
Definition: cuda_device.hpp:86
int get_num_cuda_devices()
Returns the number of CUDA devices in the system.
Definition: cuda_device.hpp:98
bool operator==(cuda_stream_view lhs, cuda_stream_view rhs)
Equality comparison operator for streams.
Definition: cuda_stream_view.hpp:177
bool operator!=(cuda_stream_view lhs, cuda_stream_view rhs)
Inequality comparison operator for streams.
Definition: cuda_stream_view.hpp:189
Strong type for a CUDA device identifier.
Definition: cuda_device.hpp:33
constexpr cuda_device_id(value_type dev_id) noexcept
Construct a cuda_device_id from the specified integer value.
Definition: cuda_device.hpp:41
constexpr value_type value() const noexcept
The wrapped integer value.
Definition: cuda_device.hpp:44
int value_type
Integer type used for device identifier.
Definition: cuda_device.hpp:34
RAII class that sets the current CUDA device to the specified device on construction and restores the previous device on destruction.
Definition: cuda_device.hpp:109
~cuda_set_device_raii() noexcept
Reactivates the previous CUDA device.
Definition: cuda_device.hpp:124
cuda_set_device_raii(cuda_device_id dev_id)
Constructs a new cuda_set_device_raii object and sets the current CUDA device to dev_id.
Definition: cuda_device.hpp:115