All Classes Files Functions Variables Typedefs Enumerations Enumerator Friends Modules Pages
cuda_async_view_memory_resource.hpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2021-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include <rmm/cuda_device.hpp>
19 #include <rmm/cuda_stream_view.hpp>
20 #include <rmm/detail/error.hpp>
21 #include <rmm/detail/export.hpp>
22 #include <rmm/detail/thrust_namespace.h>
24 
25 #include <cuda_runtime_api.h>
26 
27 #include <cstddef>
28 
29 namespace RMM_NAMESPACE {
30 namespace mr {
42  public:
53  cuda_async_view_memory_resource(cudaMemPool_t valid_pool_handle)
54  : cuda_pool_handle_{[valid_pool_handle]() {
55  RMM_EXPECTS(nullptr != valid_pool_handle, "Unexpected null pool handle.");
56  return valid_pool_handle;
57  }()}
58  {
59  // Check if cudaMallocAsync Memory pool supported
60  auto const device = rmm::get_current_cuda_device();
61  int cuda_pool_supported{};
62  auto result =
63  cudaDeviceGetAttribute(&cuda_pool_supported, cudaDevAttrMemoryPoolsSupported, device.value());
64  RMM_EXPECTS(result == cudaSuccess && cuda_pool_supported,
65  "cudaMallocAsync not supported with this CUDA driver/runtime version");
66  }
67 
73  [[nodiscard]] cudaMemPool_t pool_handle() const noexcept { return cuda_pool_handle_; }
74 
77  default;
79  default;
81  default;
83  default;
84 
85  private:
86  cudaMemPool_t cuda_pool_handle_{};
87 
97  void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override
98  {
99  void* ptr{nullptr};
100  if (bytes > 0) {
101  RMM_CUDA_TRY_ALLOC(cudaMallocFromPoolAsync(&ptr, bytes, pool_handle(), stream.value()),
102  bytes);
103  }
104  return ptr;
105  }
106 
115  void do_deallocate(void* ptr,
116  [[maybe_unused]] std::size_t bytes,
117  rmm::cuda_stream_view stream) override
118  {
119  if (ptr != nullptr) { RMM_ASSERT_CUDA_SUCCESS(cudaFreeAsync(ptr, stream.value())); }
120  }
121 
129  [[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
130  {
131  return dynamic_cast<cuda_async_view_memory_resource const*>(&other) != nullptr;
132  }
133 };
134  // end of group
136 } // namespace mr
137 } // namespace RMM_NAMESPACE
Strongly-typed non-owning wrapper for CUDA streams with default constructor.
Definition: cuda_stream_view.hpp:39
constexpr cudaStream_t value() const noexcept
Get the wrapped stream.
Definition: cuda_stream_view.hpp:73
device_memory_resource derived class that uses cudaMallocAsync/cudaFreeAsync for allocation/deallocation.
Definition: cuda_async_view_memory_resource.hpp:41
cuda_async_view_memory_resource & operator=(cuda_async_view_memory_resource &&)=default
Default move assignment operator.
cuda_async_view_memory_resource(cudaMemPool_t valid_pool_handle)
Constructs a cuda_async_view_memory_resource which uses an existing CUDA memory pool.
Definition: cuda_async_view_memory_resource.hpp:53
cuda_async_view_memory_resource(cuda_async_view_memory_resource &&)=default
Default move constructor.
cudaMemPool_t pool_handle() const noexcept
Returns the underlying native handle to the CUDA pool.
Definition: cuda_async_view_memory_resource.hpp:73
cuda_async_view_memory_resource(cuda_async_view_memory_resource const &)=default
Default copy constructor.
cuda_async_view_memory_resource & operator=(cuda_async_view_memory_resource const &)=default
Default copy assignment operator.
Base class for all librmm device memory allocation.
Definition: device_memory_resource.hpp:93
cuda_device_id get_current_cuda_device()
Returns a cuda_device_id for the current device.
Definition: cuda_device.hpp:99