cuda_stream_view.hpp
1 /*
2  * Copyright (c) 2020, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <rmm/detail/error.hpp>

#include <cuda_runtime_api.h>

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <ostream>
26 
27 namespace rmm {
28 
35  public:
36  constexpr cuda_stream_view() = default;
37  constexpr cuda_stream_view(cuda_stream_view const&) = default;
38  constexpr cuda_stream_view(cuda_stream_view&&) = default;
39  constexpr cuda_stream_view& operator=(cuda_stream_view const&) = default;
40  constexpr cuda_stream_view& operator=(cuda_stream_view&&) = default;
41  ~cuda_stream_view() = default;
42 
43  // TODO disable construction from 0 after cuDF and others adopt cuda_stream_view
44  // cuda_stream_view(int) = delete; //< Prevent cast from 0
45  // cuda_stream_view(std::nullptr_t) = delete; //< Prevent cast from nullptr
46  // TODO also disable implicit conversion from cudaStream_t
47 
51  constexpr cuda_stream_view(cudaStream_t stream) noexcept : stream_{stream} {}
52 
58  constexpr cudaStream_t value() const noexcept { return stream_; }
59 
63  explicit constexpr operator cudaStream_t() const noexcept { return value(); }
64 
68  bool is_per_thread_default() const noexcept
69  {
70 #ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
71  return value() == cudaStreamPerThread || value() == 0;
72 #else
73  return value() == cudaStreamPerThread;
74 #endif
75  }
76 
80  bool is_default() const noexcept
81  {
82 #ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
83  return value() == cudaStreamLegacy;
84 #else
85  return value() == cudaStreamLegacy || value() == 0;
86 #endif
87  }
88 
96  void synchronize() const { RMM_CUDA_TRY(cudaStreamSynchronize(stream_)); }
97 
103  void synchronize_no_throw() const noexcept
104  {
105  RMM_ASSERT_CUDA_SUCCESS(cudaStreamSynchronize(stream_));
106  }
107 
108  private:
109  cudaStream_t stream_{0};
110 };
111 
115 static constexpr cuda_stream_view cuda_stream_default{};
116 
120 static cuda_stream_view cuda_stream_legacy{cudaStreamLegacy};
121 
125 static cuda_stream_view cuda_stream_per_thread{cudaStreamPerThread};
126 
134 inline bool operator==(cuda_stream_view lhs, cuda_stream_view rhs)
135 {
136  return lhs.value() == rhs.value();
137 }
138 
146 inline bool operator!=(cuda_stream_view lhs, cuda_stream_view rhs) { return not(lhs == rhs); }
147 
155 inline std::ostream& operator<<(std::ostream& os, cuda_stream_view sv)
156 {
157  os << sv.value();
158  return os;
159 }
160 
161 } // namespace rmm
rmm::cuda_stream_view::is_default
bool is_default() const noexcept
Return true if the wrapped stream is the CUDA legacy default stream: cudaStreamLegacy, or the null stream (0) when not compiled with CUDA_API_PER_THREAD_DEFAULT_STREAM.
Definition: cuda_stream_view.hpp:80
rmm::cuda_stream_view
Strongly-typed non-owning wrapper for CUDA streams with default constructor.
Definition: cuda_stream_view.hpp:34
rmm::cuda_stream_view::cuda_stream_view
constexpr cuda_stream_view(cudaStream_t stream) noexcept
Implicit conversion from cudaStream_t.
Definition: cuda_stream_view.hpp:51
rmm::cuda_stream_view::synchronize
void synchronize() const
Synchronize the viewed CUDA stream.
Definition: cuda_stream_view.hpp:96
rmm::cuda_stream_view::value
constexpr cudaStream_t value() const noexcept
Get the wrapped stream.
Definition: cuda_stream_view.hpp:58
rmm::cuda_stream_view::synchronize_no_throw
void synchronize_no_throw() const noexcept
Synchronize the viewed CUDA stream. Does not throw if there is an error.
Definition: cuda_stream_view.hpp:103
rmm::cuda_stream_view::is_per_thread_default
bool is_per_thread_default() const noexcept
Return true if the wrapped stream is the CUDA per-thread default stream: cudaStreamPerThread, or the null stream (0) when compiled with CUDA_API_PER_THREAD_DEFAULT_STREAM.
Definition: cuda_stream_view.hpp:68