device_scalar.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2026, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #pragma once
7 
9 #include <rmm/detail/export.hpp>
10 #include <rmm/device_uvector.hpp>
12 #include <rmm/resource_ref.hpp>
13 
14 #include <type_traits>
15 
16 namespace RMM_NAMESPACE {
30 template <typename T>
32  public:
33  static_assert(std::is_trivially_copyable_v<T>, "Scalar type must be trivially copyable");
34 
39  using pointer =
43 
44  RMM_EXEC_CHECK_DISABLE
45  ~device_scalar() = default;
46 
47  RMM_EXEC_CHECK_DISABLE
48  device_scalar(device_scalar&&) noexcept = default;
49 
55  device_scalar& operator=(device_scalar&&) noexcept = default;
56 
60  device_scalar(device_scalar const&) = delete;
61 
65  device_scalar& operator=(device_scalar const&) = delete;
66 
70  device_scalar() = delete;
71 
86  explicit device_scalar(
87  cuda_stream_view stream,
88  cuda::mr::any_resource<cuda::mr::device_accessible> mr = mr::get_current_device_resource_ref())
89  : _storage{1, stream, std::move(mr)}
90  {
91  }
92 
111  explicit device_scalar(
112  value_type const& initial_value,
113  cuda_stream_view stream,
114  cuda::mr::any_resource<cuda::mr::device_accessible> mr = mr::get_current_device_resource_ref())
115  : _storage{1, stream, std::move(mr)}
116  {
117  set_value_async(initial_value, stream);
118  }
119 
133  device_scalar const& other,
134  cuda_stream_view stream,
135  cuda::mr::any_resource<cuda::mr::device_accessible> mr = mr::get_current_device_resource_ref())
136  : _storage{other._storage, stream, std::move(mr)}
137  {
138  }
139 
156  [[nodiscard]] value_type value(cuda_stream_view stream) const
157  {
158  return _storage.front_element(stream);
159  }
160 
194  void set_value_async(value_type const& value, cuda_stream_view stream)
195  {
196  _storage.set_element_async(0, value, stream);
197  }
198 
199  // Disallow passing literals to set_value to avoid race conditions where the memory holding the
200  // literal can be freed before the async memcpy / memset executes.
201  void set_value_async(value_type&&, cuda_stream_view) = delete;
202 
218  {
219  _storage.set_element_to_zero_async(value_type{0}, stream);
220  }
221 
232  [[nodiscard]] pointer data() noexcept { return static_cast<pointer>(_storage.data()); }
233 
244  [[nodiscard]] const_pointer data() const noexcept
245  {
246  return static_cast<const_pointer>(_storage.data());
247  }
248 
252  [[nodiscard]] constexpr size_type size() const noexcept { return 1; }
253 
257  [[nodiscard]] cuda_stream_view stream() const noexcept { return _storage.stream(); }
258 
264  void set_stream(cuda_stream_view stream) noexcept { _storage.set_stream(stream); }
265 
266  private:
267  rmm::device_uvector<T> _storage;
268 };
269  // end of group
271 } // namespace RMM_NAMESPACE
Strongly-typed non-owning wrapper for CUDA streams with default constructor.
Definition: cuda_stream_view.hpp:28
Container for a single object of type T in device memory.
Definition: device_scalar.hpp:31
typename device_uvector< T >::value_type value_type
T, the type of the scalar element.
Definition: device_scalar.hpp:35
constexpr size_type size() const noexcept
The size of the scalar: always 1.
Definition: device_scalar.hpp:252
device_scalar(value_type const &initial_value, cuda_stream_view stream, cuda::mr::any_resource< cuda::mr::device_accessible > mr=mr::get_current_device_resource_ref())
Construct a new device_scalar with an initial value.
Definition: device_scalar.hpp:111
const_pointer data() const noexcept
Returns const pointer to object in device memory.
Definition: device_scalar.hpp:244
typename device_uvector< T >::const_reference const_reference
const value_type&
Definition: device_scalar.hpp:38
typename device_uvector< T >::size_type size_type
The type used for the size.
Definition: device_scalar.hpp:36
device_scalar(device_scalar &&) noexcept=default
Default move constructor.
device_scalar(device_scalar const &other, cuda_stream_view stream, cuda::mr::any_resource< cuda::mr::device_accessible > mr=mr::get_current_device_resource_ref())
Construct a new device_scalar by deep copying the contents of another device_scalar, using the specified stream and memory resource.
Definition: device_scalar.hpp:132
void set_stream(cuda_stream_view stream) noexcept
Sets the stream to be used for deallocation.
Definition: device_scalar.hpp:264
typename device_uvector< T >::const_pointer const_pointer
Definition: device_scalar.hpp:42
cuda_stream_view stream() const noexcept
Stream associated with the device memory allocation.
Definition: device_scalar.hpp:257
pointer data() noexcept
Returns pointer to object in device memory.
Definition: device_scalar.hpp:232
value_type value(cuda_stream_view stream) const
Copies the value from device to host, synchronizes, and returns the value.
Definition: device_scalar.hpp:156
void set_value_to_zero_async(cuda_stream_view stream)
Sets the value of the device_scalar to zero on the specified stream.
Definition: device_scalar.hpp:217
void set_value_async(value_type const &value, cuda_stream_view stream)
Sets the value of the device_scalar to the given `value` on the specified stream.
Definition: device_scalar.hpp:194
typename device_uvector< T >::pointer pointer
The type of the pointer returned by data()
Definition: device_scalar.hpp:40
typename device_uvector< T >::reference reference
value_type&
Definition: device_scalar.hpp:37
An uninitialized vector of elements in device memory.
Definition: device_uvector.hpp:68
value_type * pointer
The type of the pointer returned by data()
Definition: device_uvector.hpp:78
std::size_t size_type
The type used for the size of the vector.
Definition: device_uvector.hpp:74
T value_type
Stored value type.
Definition: device_uvector.hpp:73
value_type & reference
Reference type returned by operator[](size_type)
Definition: device_uvector.hpp:75
value_type const * const_pointer
The type of the pointer returned by data() const.
Definition: device_uvector.hpp:79
value_type const & const_reference
Constant reference type returned by operator[](size_type) const.
Definition: device_uvector.hpp:77
device_async_resource_ref get_current_device_resource_ref()
Get the device_async_resource_ref for the current device.
Definition: per_device_resource.hpp:223
Management of per-device memory resources.