All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Modules Pages
cuda_stream_view.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2020-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <rmm/detail/error.hpp>
20 #include <rmm/detail/export.hpp>
21 
22 #include <cuda/stream_ref>
23 #include <cuda_runtime_api.h>
24 
25 #include <cstddef>
26 
27 namespace RMM_NAMESPACE {
40  public:
41  constexpr cuda_stream_view() = default;
42  ~cuda_stream_view() = default;
43  constexpr cuda_stream_view(cuda_stream_view const&) = default;
44  constexpr cuda_stream_view(cuda_stream_view&&) = default;
46  default;
48  default;
49 
50  // Disable construction from literal 0
51  constexpr cuda_stream_view(int) = delete; //< Prevent cast from 0
52  constexpr cuda_stream_view(std::nullptr_t) = delete; //< Prevent cast from nullptr
53 
59  constexpr cuda_stream_view(cudaStream_t stream) noexcept : stream_{stream} {}
60 
66  constexpr cuda_stream_view(cuda::stream_ref stream) noexcept : stream_{stream.get()} {}
67 
73  [[nodiscard]] constexpr cudaStream_t value() const noexcept { return stream_; }
74 
80  constexpr operator cudaStream_t() const noexcept { return value(); }
81 
87  constexpr operator cuda::stream_ref() const noexcept { return value(); }
88 
92  [[nodiscard]] inline bool is_per_thread_default() const noexcept;
93 
97  [[nodiscard]] inline bool is_default() const noexcept;
98 
106  void synchronize() const { RMM_CUDA_TRY(cudaStreamSynchronize(stream_)); }
107 
113  void synchronize_no_throw() const noexcept
114  {
115  RMM_ASSERT_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
116  }
117 
118  private:
119  cudaStream_t stream_{};
120 };
121 
126 
132  cudaStreamLegacy // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
133 };
134 
139  cudaStreamPerThread // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
140 };
141 
142 // Need to avoid putting is_per_thread_default and is_default into the group twice.
143 // end of group
144 
145 [[nodiscard]] inline bool cuda_stream_view::is_per_thread_default() const noexcept
146 {
147 #ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
148  return value() == cuda_stream_per_thread || value() == nullptr;
149 #else
150  return value() == cuda_stream_per_thread;
151 #endif
152 }
153 
154 [[nodiscard]] inline bool cuda_stream_view::is_default() const noexcept
155 {
156 #ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
157  return value() == cuda_stream_legacy;
158 #else
159  return value() == cuda_stream_legacy || value() == nullptr;
160 #endif
161 }
162 
176 {
177  return lhs.value() == rhs.value();
178 }
179 
187 inline bool operator!=(cuda_stream_view lhs, cuda_stream_view rhs) { return not(lhs == rhs); }
188 
196 inline std::ostream& operator<<(std::ostream& os, cuda_stream_view stream)
197 {
198  os << stream.value();
199  return os;
200 }
201
202 // end of group
203 } // namespace RMM_NAMESPACE
Strongly-typed non-owning wrapper for CUDA streams with default constructor.
Definition: cuda_stream_view.hpp:39
constexpr cudaStream_t value() const noexcept
Get the wrapped stream.
Definition: cuda_stream_view.hpp:73
constexpr cuda_stream_view & operator=(cuda_stream_view const &)=default
Default copy assignment operator.
constexpr cuda_stream_view & operator=(cuda_stream_view &&)=default
Default move assignment operator.
constexpr cuda_stream_view(cudaStream_t stream) noexcept
Constructor from a cudaStream_t.
Definition: cuda_stream_view.hpp:59
constexpr cuda_stream_view(cuda::stream_ref stream) noexcept
Implicit conversion from stream_ref.
Definition: cuda_stream_view.hpp:66
constexpr cuda_stream_view(cuda_stream_view &&)=default
Default move constructor.
void synchronize_no_throw() const noexcept
Synchronize the viewed CUDA stream. Does not throw if there is an error.
Definition: cuda_stream_view.hpp:113
constexpr cuda_stream_view(cuda_stream_view const &)=default
Default copy constructor.
bool operator==(cuda_stream_view lhs, cuda_stream_view rhs)
Equality comparison operator for streams.
Definition: cuda_stream_view.hpp:175
static constexpr cuda_stream_view cuda_stream_default
Static cuda_stream_view of the default stream (stream 0), for convenience.
Definition: cuda_stream_view.hpp:125
std::ostream & operator<<(std::ostream &os, cuda_stream_view stream)
Output stream operator for printing / logging streams.
Definition: cuda_stream_view.hpp:196
static const cuda_stream_view cuda_stream_per_thread
Static cuda_stream_view of cudaStreamPerThread, for convenience.
Definition: cuda_stream_view.hpp:138
bool operator!=(cuda_stream_view lhs, cuda_stream_view rhs)
Inequality comparison operator for streams.
Definition: cuda_stream_view.hpp:187
static const cuda_stream_view cuda_stream_legacy
Static cuda_stream_view of cudaStreamLegacy, for convenience.
Definition: cuda_stream_view.hpp:131