cuda_async_managed_memory_resource.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
7 #include <rmm/cuda_device.hpp>
9 #include <rmm/detail/error.hpp>
10 #include <rmm/detail/export.hpp>
11 #include <rmm/detail/runtime_capabilities.hpp>
12 #include <rmm/detail/thrust_namespace.h>
15 
16 #include <cuda/std/type_traits>
17 #include <cuda_runtime_api.h>
18 
19 #include <cstddef>
20 #include <cstdint>
21 #include <optional>
22 
23 namespace RMM_NAMESPACE {
24 namespace mr {
37  public:
49  {
50  // Check if managed memory pools are supported
51  RMM_EXPECTS(rmm::detail::runtime_async_managed_alloc::is_supported(),
52  "cuda_async_managed_memory_resource requires CUDA 13.0 or higher");
53 
54 #if defined(CUDA_VERSION) && CUDA_VERSION >= RMM_MIN_ASYNC_MANAGED_ALLOC_CUDA_VERSION
55  cudaMemPool_t managed_pool_handle{};
56  cudaMemLocation location{.type = cudaMemLocationTypeDevice,
58  RMM_CUDA_TRY(
59  cudaMemGetDefaultMemPool(&managed_pool_handle, &location, cudaMemAllocationTypeManaged));
60  pool_ = cuda_async_view_memory_resource{managed_pool_handle};
61 #endif
62  }
63 
69  [[nodiscard]] cudaMemPool_t pool_handle() const noexcept { return pool_.pool_handle(); }
70 
72  cuda_async_managed_memory_resource(cuda_async_managed_memory_resource const&) = delete;
73  cuda_async_managed_memory_resource(cuda_async_managed_memory_resource&&) = delete;
74  cuda_async_managed_memory_resource& operator=(cuda_async_managed_memory_resource const&) = delete;
75  cuda_async_managed_memory_resource& operator=(cuda_async_managed_memory_resource&&) = delete;
76 
77  private:
78  cuda_async_view_memory_resource pool_{};
79 
89  void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override
90  {
91  return pool_.allocate(stream, bytes);
92  }
93 
102  void do_deallocate(void* ptr, std::size_t bytes, rmm::cuda_stream_view stream) noexcept override
103  {
104  pool_.deallocate(stream, ptr, bytes);
105  }
106 
114  [[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
115  {
116  auto const* async_mr = dynamic_cast<cuda_async_managed_memory_resource const*>(&other);
117  return (async_mr != nullptr) && (this->pool_handle() == async_mr->pool_handle());
118  }
119 };
120 
121 // static property checks
122 static_assert(rmm::detail::polyfill::resource<cuda_async_managed_memory_resource>);
123 static_assert(rmm::detail::polyfill::async_resource<cuda_async_managed_memory_resource>);
124 static_assert(rmm::detail::polyfill::resource_with<cuda_async_managed_memory_resource,
125  cuda::mr::device_accessible>);
126 static_assert(rmm::detail::polyfill::async_resource_with<cuda_async_managed_memory_resource,
127  cuda::mr::device_accessible>);
128  // end of group
130 } // namespace mr
131 } // namespace RMM_NAMESPACE
Strongly-typed non-owning wrapper for CUDA streams with default constructor.
Definition: cuda_stream_view.hpp:28
device_memory_resource derived class that uses cudaMallocFromPoolAsync/cudaFreeFromPoolAsync with a managed memory pool for allocation/deallocation.
Definition: cuda_async_managed_memory_resource.hpp:36
cuda_async_managed_memory_resource()
Constructs a cuda_async_managed_memory_resource with the default managed memory pool for the current device.
Definition: cuda_async_managed_memory_resource.hpp:48
cudaMemPool_t pool_handle() const noexcept
Returns the underlying native handle to the CUDA pool.
Definition: cuda_async_managed_memory_resource.hpp:69
device_memory_resource derived class that uses cudaMallocAsync/cudaFreeAsync for allocation/deallocation.
Definition: cuda_async_view_memory_resource.hpp:30
Base class for all librmm device memory allocators.
Definition: device_memory_resource.hpp:83
void * allocate(cuda_stream_view stream, std::size_t bytes, std::size_t alignment=rmm::CUDA_ALLOCATION_ALIGNMENT)
Allocates memory of size at least `bytes` on the specified stream.
Definition: device_memory_resource.hpp:322
cuda_device_id get_current_cuda_device()
Returns a cuda_device_id for the current device.
constexpr value_type value() const noexcept
The wrapped integer value.
Definition: cuda_device.hpp:43