callback_memory_resource.hpp
1 /*
2  * Copyright (c) 2022, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
#include <rmm/mr/device/device_memory_resource.hpp>

#include <cstddef>
#include <functional>
#include <stdexcept>
#include <utility>
23 
24 namespace rmm::mr {
25 
41 using allocate_callback_t = std::function<void*(std::size_t, cuda_stream_view, void*)>;
42 
59 using deallocate_callback_t = std::function<void(void*, std::size_t, cuda_stream_view, void*)>;
60 
66  public:
82  callback_memory_resource(allocate_callback_t allocate_callback,
83  deallocate_callback_t deallocate_callback,
84  void* allocate_callback_arg = nullptr,
85  void* deallocate_callback_arg = nullptr) noexcept
86  : allocate_callback_(allocate_callback),
87  deallocate_callback_(deallocate_callback),
88  allocate_callback_arg_(allocate_callback_arg),
89  deallocate_callback_arg_(deallocate_callback_arg)
90  {
91  }
92 
93  callback_memory_resource() = delete;
94  ~callback_memory_resource() override = default;
96  callback_memory_resource& operator=(callback_memory_resource const&) = delete;
98  callback_memory_resource& operator=(callback_memory_resource&&) noexcept = default;
99 
100  private:
101  void* do_allocate(std::size_t bytes, cuda_stream_view stream) override
102  {
103  return allocate_callback_(bytes, stream, allocate_callback_arg_);
104  }
105 
106  void do_deallocate(void* ptr, std::size_t bytes, cuda_stream_view stream) override
107  {
108  deallocate_callback_(ptr, bytes, stream, deallocate_callback_arg_);
109  }
110 
111  [[nodiscard]] std::pair<std::size_t, std::size_t> do_get_mem_info(cuda_stream_view) const override
112  {
113  throw std::runtime_error("cannot get free / total memory");
114  }
115 
116  [[nodiscard]] virtual bool supports_streams() const noexcept { return false; }
117  [[nodiscard]] virtual bool supports_get_mem_info() const noexcept { return false; }
118 
119  allocate_callback_t allocate_callback_;
120  deallocate_callback_t deallocate_callback_;
121  void* allocate_callback_arg_;
122  void* deallocate_callback_arg_;
123 };
124 
125 } // namespace rmm::mr
rmm::mr::callback_memory_resource
A device memory resource that uses the provided callbacks for memory allocation and deallocation.
Definition: callback_memory_resource.hpp:65
rmm::cuda_stream_view
Strongly-typed non-owning wrapper for CUDA streams with default constructor.
Definition: cuda_stream_view.hpp:34
rmm::mr::callback_memory_resource::callback_memory_resource
callback_memory_resource(allocate_callback_t allocate_callback, deallocate_callback_t deallocate_callback, void *allocate_callback_arg=nullptr, void *deallocate_callback_arg=nullptr) noexcept
Construct a new callback memory resource.
Definition: callback_memory_resource.hpp:82
rmm::mr::device_memory_resource
Base class for all RMM device memory resources.
Definition: device_memory_resource.hpp:82