managed_memory_resource.hpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2019-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
8 #include <rmm/detail/error.hpp>
9 #include <rmm/detail/export.hpp>
11 
12 #include <cstddef>
13 
14 namespace RMM_NAMESPACE {
15 namespace mr {
26  public:
/// @brief Default constructor (compiler-generated via `= default`).
27  managed_memory_resource() = default;
/// @brief Defaulted destructor; `override` indicates the base class destructor is virtual.
/// NOTE(review): allocations are released only through do_deallocate, not here — confirm no owned state in the elided members.
28  ~managed_memory_resource() override = default;
32  default;
34  default;
35 
36  private:
48  void* do_allocate(std::size_t bytes, [[maybe_unused]] cuda_stream_view stream) override
49  {
50  // FIXME: Unlike cudaMalloc, cudaMallocManaged will throw an error for 0
51  // size allocations.
52  if (bytes == 0) { return nullptr; }
53 
54  void* ptr{nullptr};
55  RMM_CUDA_TRY_ALLOC(cudaMallocManaged(&ptr, bytes), bytes);
56  return ptr;
57  }
58 
69  void do_deallocate(void* ptr,
70  [[maybe_unused]] std::size_t bytes,
71  [[maybe_unused]] cuda_stream_view stream) noexcept override
72  {
73  RMM_ASSERT_CUDA_SUCCESS(cudaFree(ptr));
74  }
75 
86  [[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
87  {
88  return dynamic_cast<managed_memory_resource const*>(&other) != nullptr;
89  }
90 };
91  // end of group
93 } // namespace mr
94 } // namespace RMM_NAMESPACE
Strongly-typed non-owning wrapper for CUDA streams with default constructor.
Definition: cuda_stream_view.hpp:28
Base class for all librmm device memory allocators.
Definition: device_memory_resource.hpp:83
device_memory_resource derived class that uses cudaMallocManaged/cudaFree for allocation/deallocation.
Definition: managed_memory_resource.hpp:25
managed_memory_resource(managed_memory_resource &&)=default
Default move constructor.
managed_memory_resource & operator=(managed_memory_resource &&)=default
Default move assignment operator.
managed_memory_resource & operator=(managed_memory_resource const &)=default
Default copy assignment operator.
managed_memory_resource(managed_memory_resource const &)=default
Default copy constructor.