cuda_stream_view.hpp
1 /*
2  * Copyright (c) 2020-2021, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <rmm/detail/error.hpp>

#include <cuda_runtime_api.h>

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <ostream>
26 
27 namespace rmm {
28 
35  public:
36  constexpr cuda_stream_view() = default;
37  constexpr cuda_stream_view(cuda_stream_view const&) = default;
38  constexpr cuda_stream_view(cuda_stream_view&&) = default;
39  constexpr cuda_stream_view& operator=(cuda_stream_view const&) = default;
40  constexpr cuda_stream_view& operator=(cuda_stream_view&&) = default;
41  ~cuda_stream_view() = default;
42 
43  // Disable construction from literal 0
44  constexpr cuda_stream_view(int) = delete; //< Prevent cast from 0
45  constexpr cuda_stream_view(std::nullptr_t) = delete; //< Prevent cast from nullptr
46 
50  constexpr cuda_stream_view(cudaStream_t stream) noexcept : stream_{stream} {}
51 
57  [[nodiscard]] constexpr cudaStream_t value() const noexcept { return stream_; }
58 
62  constexpr operator cudaStream_t() const noexcept { return value(); }
63 
67  [[nodiscard]] inline bool is_per_thread_default() const noexcept;
68 
72  [[nodiscard]] inline bool is_default() const noexcept;
73 
81  void synchronize() const { RMM_CUDA_TRY(cudaStreamSynchronize(stream_)); }
82 
88  void synchronize_no_throw() const noexcept
89  {
90  RMM_ASSERT_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
91  }
92 
93  private:
94  cudaStream_t stream_{};
95 };
96 
100 static constexpr cuda_stream_view cuda_stream_default{};
101 
106 static const cuda_stream_view cuda_stream_legacy{
107  cudaStreamLegacy // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
108 };
109 
113 static const cuda_stream_view cuda_stream_per_thread{
114  cudaStreamPerThread // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
115 };
116 
117 [[nodiscard]] inline bool cuda_stream_view::is_per_thread_default() const noexcept
118 {
119 #ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
120  return value() == cuda_stream_per_thread || value() == nullptr;
121 #else
122  return value() == cuda_stream_per_thread;
123 #endif
124 }
125 
129 [[nodiscard]] inline bool cuda_stream_view::is_default() const noexcept
130 {
131 #ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
132  return value() == cuda_stream_legacy;
133 #else
134  return value() == cuda_stream_legacy || value() == nullptr;
135 #endif
136 }
137 
145 inline bool operator==(cuda_stream_view lhs, cuda_stream_view rhs)
146 {
147  return lhs.value() == rhs.value();
148 }
149 
157 inline bool operator!=(cuda_stream_view lhs, cuda_stream_view rhs) { return not(lhs == rhs); }
158 
166 inline std::ostream& operator<<(std::ostream& os, cuda_stream_view stream)
167 {
168  os << stream.value();
169  return os;
170 }
171 
172 } // namespace rmm
rmm::cuda_stream_view::is_default
bool is_default() const noexcept
Return true if the wrapped stream is the CUDA legacy default stream: either cudaStreamLegacy, or the null stream when CUDA_API_PER_THREAD_DEFAULT_STREAM is not defined.
Definition: cuda_stream_view.hpp:129
rmm::cuda_stream_view
Strongly-typed non-owning wrapper for CUDA streams with default constructor.
Definition: cuda_stream_view.hpp:34
rmm::cuda_stream_view::cuda_stream_view
constexpr cuda_stream_view(cudaStream_t stream) noexcept
Implicit conversion from cudaStream_t.
Definition: cuda_stream_view.hpp:50
rmm::cuda_stream_view::synchronize
void synchronize() const
Synchronize the viewed CUDA stream.
Definition: cuda_stream_view.hpp:81
rmm::cuda_stream_view::value
constexpr cudaStream_t value() const noexcept
Get the wrapped stream.
Definition: cuda_stream_view.hpp:57
rmm::cuda_stream_view::synchronize_no_throw
void synchronize_no_throw() const noexcept
Synchronize the viewed CUDA stream. Does not throw if there is an error.
Definition: cuda_stream_view.hpp:88
rmm::cuda_stream_view::is_per_thread_default
bool is_per_thread_default() const noexcept
Return true if the wrapped stream is the CUDA per-thread default stream.
Definition: cuda_stream_view.hpp:117