cuda_stream_view.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2020-2021, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <rmm/detail/error.hpp>
20 
21 #include <cuda_runtime_api.h>
22 
23 #include <cuda/stream_ref>
24 
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <ostream>
28 
29 namespace rmm {
42  public:
43  constexpr cuda_stream_view() = default;
44  ~cuda_stream_view() = default;
45  constexpr cuda_stream_view(cuda_stream_view const&) = default;
46  constexpr cuda_stream_view(cuda_stream_view&&) = default;
48  default;
50  default;
51 
52  // Disable construction from literal 0
53  constexpr cuda_stream_view(int) = delete; //< Prevent cast from 0
54  constexpr cuda_stream_view(std::nullptr_t) = delete; //< Prevent cast from nullptr
55 
61  constexpr cuda_stream_view(cudaStream_t stream) noexcept : stream_{stream} {}
62 
68  constexpr cuda_stream_view(cuda::stream_ref stream) noexcept : stream_{stream.get()} {}
69 
75  [[nodiscard]] constexpr cudaStream_t value() const noexcept { return stream_; }
76 
82  constexpr operator cudaStream_t() const noexcept { return value(); }
83 
89  constexpr operator cuda::stream_ref() const noexcept { return value(); }
90 
94  [[nodiscard]] inline bool is_per_thread_default() const noexcept;
95 
99  [[nodiscard]] inline bool is_default() const noexcept;
100 
108  void synchronize() const { RMM_CUDA_TRY(cudaStreamSynchronize(stream_)); }
109 
115  void synchronize_no_throw() const noexcept
116  {
117  RMM_ASSERT_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
118  }
119 
120  private:
121  cudaStream_t stream_{};
122 };
123 
127 static constexpr cuda_stream_view cuda_stream_default{};
128 
133 static const cuda_stream_view cuda_stream_legacy{
134  cudaStreamLegacy // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
135 };
136 
140 static const cuda_stream_view cuda_stream_per_thread{
141  cudaStreamPerThread // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
142 };
143 
// Note: is_per_thread_default and is_default are defined below, after the Doxygen group is
// closed, so that they are not listed in the group twice.
// end of group
146 
147 [[nodiscard]] inline bool cuda_stream_view::is_per_thread_default() const noexcept
148 {
149 #ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
150  return value() == cuda_stream_per_thread || value() == nullptr;
151 #else
152  return value() == cuda_stream_per_thread;
153 #endif
154 }
155 
156 [[nodiscard]] inline bool cuda_stream_view::is_default() const noexcept
157 {
158 #ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
159  return value() == cuda_stream_legacy;
160 #else
161  return value() == cuda_stream_legacy || value() == nullptr;
162 #endif
163 }
164 
178 {
179  return lhs.value() == rhs.value();
180 }
181 
189 inline bool operator!=(cuda_stream_view lhs, cuda_stream_view rhs) { return not(lhs == rhs); }
190 
198 inline std::ostream& operator<<(std::ostream& os, cuda_stream_view stream)
199 {
200  os << stream.value();
201  return os;
202 }
203  // end of group
205 } // namespace rmm
Strongly-typed non-owning wrapper for CUDA streams with default constructor.
Definition: cuda_stream_view.hpp:41
constexpr cudaStream_t value() const noexcept
Get the wrapped stream.
Definition: cuda_stream_view.hpp:75
void synchronize() const
Synchronize the viewed CUDA stream.
Definition: cuda_stream_view.hpp:108
constexpr cuda_stream_view & operator=(cuda_stream_view const &)=default
Default copy assignment operator.
constexpr cuda_stream_view & operator=(cuda_stream_view &&)=default
Default move assignment operator.
constexpr cuda_stream_view(cudaStream_t stream) noexcept
Constructor from a cudaStream_t.
Definition: cuda_stream_view.hpp:61
constexpr cuda_stream_view(cuda::stream_ref stream) noexcept
Implicit conversion from stream_ref.
Definition: cuda_stream_view.hpp:68
bool is_per_thread_default() const noexcept
true if the wrapped stream is the CUDA per-thread default stream (cudaStreamPerThread, or stream 0 when compiled with CUDA_API_PER_THREAD_DEFAULT_STREAM)
Definition: cuda_stream_view.hpp:147
constexpr cuda_stream_view(cuda_stream_view &&)=default
Default move constructor.
void synchronize_no_throw() const noexcept
Synchronize the viewed CUDA stream. Does not throw if there is an error.
Definition: cuda_stream_view.hpp:115
bool is_default() const noexcept
true if the wrapped stream is the CUDA legacy default stream (cudaStreamLegacy, or stream 0 when not compiled with CUDA_API_PER_THREAD_DEFAULT_STREAM)
Definition: cuda_stream_view.hpp:156
constexpr cuda_stream_view(cuda_stream_view const &)=default
Default copy constructor.
bool operator==(cuda_stream_view lhs, cuda_stream_view rhs)
Equality comparison operator for streams.
Definition: cuda_stream_view.hpp:177
std::ostream & operator<<(std::ostream &os, cuda_stream_view stream)
Output stream operator for printing / logging streams.
Definition: cuda_stream_view.hpp:198
bool operator!=(cuda_stream_view lhs, cuda_stream_view rhs)
Inequality comparison operator for streams.
Definition: cuda_stream_view.hpp:189