18 #include <rmm/detail/error.hpp>
19 #include <rmm/logger.hpp>
20 #include <rmm/mr/device/detail/arena.hpp>
21 #include <rmm/mr/device/device_memory_resource.hpp>
23 #include <cuda_runtime_api.h>
25 #include <spdlog/common.h>
26 #include <spdlog/fmt/ostr.h>
30 #include <shared_mutex>
// Class template parameterized on the upstream resource type. The class declaration, the
// constructor head and the `upstream_mr` parameter are on lines elided from this excerpt
// (orig. lines 75-87).
74 template <
typename Upstream>
/**
 * Constructor tail: builds global_arena_ from `upstream_mr` and `arena_size`.
 * When `dump_log_on_failure` is true, it also creates a file logger used to dump the arena
 * state on allocation failure.
 *
 * @param arena_size Forwarded to the global arena. Defaults to std::nullopt (presumably
 *                   "let the global arena choose a size" — confirm).
 * @param dump_log_on_failure If true, log the memory state when an allocation fails.
 */
88 std::optional<std::size_t> arena_size = std::nullopt,
89 bool dump_log_on_failure =
false)
90 : global_arena_{upstream_mr, arena_size}, dump_log_on_failure_{dump_log_on_failure}
// The logger is only created when dumping is enabled. dump_memory_log() dereferences it and is
// only invoked when dump_log_on_failure_ is set.
92 if (dump_log_on_failure_) {
// NOTE(review): spdlog's *_mt factory functions register the logger globally by name.
// A second resource constructed with dump_log_on_failure=true may therefore collide on
// "arena_memory_dump" — confirm how duplicates are handled.
93 logger_ = spdlog::basic_logger_mt(
"arena_memory_dump",
"rmm_arena_memory_dump.log");
95 logger_->set_level(spdlog::level::info);
// --- Allocation path (the enclosing do_allocate signature is elided from this excerpt). ---
// A zero-byte request yields nullptr.
// NOTE(review): if `bytes` is std::size_t, `<= 0` is equivalent to `== 0`.
139 if (bytes <= 0) {
return nullptr; }
// Round the request up, using the same rounding as do_deallocate so both sides agree on the
// block size. The target is either a size class or, presumably in an elided #else branch,
// CUDA_ALLOCATION_ALIGNMENT.
140 #ifdef RMM_ARENA_USE_SIZE_CLASSES
141 bytes = rmm::mr::detail::arena::align_to_size_class(bytes);
143 bytes = rmm::detail::align_up(bytes, rmm::detail::CUDA_ALLOCATION_ALIGNMENT);
// The calling thread's arena for the per-thread default stream, the stream's own arena otherwise.
145 auto& arena = get_arena(stream);
// Fast path: many threads may allocate concurrently while holding mtx_ in shared mode.
148 std::shared_lock lock(mtx_);
149 void* pointer = arena.allocate(bytes);
150 if (pointer !=
nullptr) {
return pointer; }
// Slow path: retry while holding mtx_ exclusively, which excludes every shared-lock holder.
// NOTE(review): orig. line 155 is elided; it may perform extra work before the retry.
154 std::unique_lock lock(mtx_);
156 void* pointer = arena.allocate(bytes);
157 if (pointer ==
nullptr) {
// Optionally dump the arena state for diagnosis. The failure handling that follows is elided
// from this excerpt.
158 if (dump_log_on_failure_) { dump_memory_log(bytes); }
// Fragment whose enclosing function head is elided. It waits for all outstanding device work,
// then calls clean() on every per-thread and per-stream arena.
// NOTE(review): presumably the device sync ensures no in-flight kernel still uses memory that
// clean() hands back. Confirm against arena::clean().
// No lock on map_mtx_ is visible in this fragment while the arena maps are iterated.
// Presumably the caller holds mtx_ exclusively — confirm.
170 RMM_CUDA_TRY(cudaDeviceSynchronize());
171 for (
auto& thread_arena : thread_arenas_) {
172 thread_arena.second->clean();
174 for (
auto& stream_arena : stream_arenas_) {
175 stream_arena.second.clean();
/**
 * @brief Return memory to the arena that owns it.
 *
 * @param ptr Pointer to release. nullptr is ignored.
 * @param bytes Size of the allocation, rounded the same way as on the allocation path.
 *              Zero is ignored.
 * @param stream Selects the per-thread or per-stream arena tried first.
 */
187 void do_deallocate(
void* ptr, std::size_t bytes, cuda_stream_view stream)
override
189 if (ptr ==
nullptr || bytes <= 0) {
return; }
// Apply the same rounding as do_allocate so the arena sees the size it handed out.
190 #ifdef RMM_ARENA_USE_SIZE_CLASSES
191 bytes = rmm::mr::detail::arena::align_to_size_class(bytes);
193 bytes = rmm::detail::align_up(bytes, rmm::detail::CUDA_ALLOCATION_ALIGNMENT);
195 auto& arena = get_arena(stream);
// Fast path: the block usually belongs to this stream's or this thread's arena.
198 std::shared_lock lock(mtx_);
200 if (arena.deallocate(ptr, bytes, stream)) {
return; }
// The block belongs to a different arena. Wait for work on `stream` first, presumably so the
// memory is idle before an arena serving another stream or thread can reuse it. The no-throw
// variant presumably keeps deallocation from throwing.
206 stream.synchronize_no_throw();
// Searching the other arenas requires mtx_ held exclusively.
208 std::unique_lock lock(mtx_);
209 deallocate_from_other_arena(ptr, bytes, stream);
/**
 * @brief Find the arena that owns `ptr` and return the block to it.
 *
 * do_deallocate calls this while holding mtx_ exclusively. The search order is:
 * - the per-thread arenas, when `stream` is the per-thread default stream;
 * - otherwise the per-stream arenas;
 * - finally the global arena.
 * Unlike the fast path, the lookups here use the two-argument deallocate(ptr, bytes) overloads.
 *
 * @param ptr Pointer to release.
 * @param bytes Rounded allocation size.
 * @param stream Only used to choose which family of arenas to search.
 * Fails via RMM_FAIL("allocation not found") when no arena owns `ptr` (presumably throws).
 */
221 void deallocate_from_other_arena(
void* ptr, std::size_t bytes, cuda_stream_view stream)
223 if (use_per_thread_arena(stream)) {
224 for (
auto const& thread_arena : thread_arenas_) {
225 if (thread_arena.second->deallocate(ptr, bytes)) {
return; }
// NOTE(review): orig. lines 226-227 are elided. They presumably close the loop above and open
// an `else` branch for the per-stream arenas.
228 for (
auto& stream_arena : stream_arenas_) {
229 if (stream_arena.second.deallocate(ptr, bytes)) {
return; }
// Last resort: the global arena. Anything it does not own is treated as an invalid pointer.
233 if (!global_arena_.
deallocate(ptr, bytes)) { RMM_FAIL(
"allocation not found"); }
/**
 * @brief Select the arena that serves `stream`.
 *
 * @param stream Stream of the request.
 * @return The calling thread's arena if `stream` is the per-thread default stream, otherwise
 *         the arena dedicated to `stream`. Either one is created on first use.
 */
242 arena& get_arena(cuda_stream_view stream)
244 if (use_per_thread_arena(stream)) {
return get_thread_arena(); }
245 return get_stream_arena(stream);
/**
 * @brief Get the calling thread's arena, creating it on first use.
 *
 * Lookup happens under a shared lock on map_mtx_; insertion happens under an exclusive lock.
 * The map key is always the caller's own thread id, so no other thread can insert that key
 * between the two locks.
 *
 * @return Reference to an arena co-owned by thread_arenas_.
 */
253 arena& get_thread_arena()
255 auto const thread_id = std::this_thread::get_id();
257 std::shared_lock lock(map_mtx_);
258 auto const iter = thread_arenas_.find(thread_id);
259 if (iter != thread_arenas_.end()) {
return *iter->second; }
262 std::unique_lock lock(map_mtx_);
// Held by shared_ptr so the thread_local cleaner below can share ownership with the map.
263 auto thread_arena = std::make_shared<arena>(global_arena_);
264 thread_arenas_.emplace(thread_id, thread_arena);
// The cleaner presumably cleans the arena when the thread exits — confirm in arena_cleaner.
// NOTE(review): a function-local thread_local in a member of a class template exists once per
// thread per Upstream instantiation, not once per resource object. If a single thread uses two
// resources of the same type, only the first arena gets a cleaner — confirm this is intended.
265 thread_local detail::arena::arena_cleaner<Upstream> cleaner{thread_arena};
266 return *thread_arena;
/**
 * @brief Get the arena dedicated to `stream`, creating it on first use.
 *
 * Must not be called for the per-thread default stream; this is asserted.
 *
 * @param stream Stream whose raw handle keys stream_arenas_.
 * @return Reference to the arena stored by value in stream_arenas_.
 */
275 arena& get_stream_arena(cuda_stream_view stream)
277 RMM_LOGGING_ASSERT(!use_per_thread_arena(stream));
279 std::shared_lock lock(map_mtx_);
280 auto const iter = stream_arenas_.find(stream.value());
281 if (iter != stream_arenas_.end()) {
return iter->second; }
284 std::unique_lock lock(map_mtx_);
// Another thread may have inserted this stream between the two locks. In that case emplace
// leaves the existing entry untouched and at() returns it. Inserting into a std::map does not
// invalidate references to existing elements, so the returned reference stays valid.
// NOTE(review): `return stream_arenas_.emplace(...).first->second;` would avoid the second lookup.
285 stream_arenas_.emplace(stream.value(), global_arena_);
286 return stream_arenas_.at(stream.value());
/**
 * @brief Memory-info query. This resource does not report free or total memory.
 *
 * @param stream Unused.
 * @return Always {0, 0}.
 */
296 std::pair<std::size_t, std::size_t> do_get_mem_info(cuda_stream_view stream)
const override
298 return std::make_pair(0, 0);
/**
 * @brief Write the arena state to logger_ after an allocation failure.
 *
 * The visible caller (the allocation slow path) only calls this when dump_log_on_failure_ is
 * true, which is the only case in which logger_ is created.
 * NOTE(review): the rest of the body (orig. line 309 and lines 312+) is elided. `bytes` is
 * presumably reported there.
 *
 * @param bytes The (already rounded) size of the allocation that failed.
 */
306 void dump_memory_log(
size_t bytes)
308 logger_->info(
"**************************************************");
310 logger_->info(
"**************************************************");
311 logger_->info(
"Global arena:");
/**
 * @brief Whether per-thread arenas (rather than per-stream arenas) should serve `stream`.
 *
 * @return true iff `stream` is the per-thread default stream.
 */
322 static bool use_per_thread_arena(cuda_stream_view stream)
324 return stream.is_per_thread_default();
// Shared backing arena. Per-thread and per-stream arenas are constructed from it, and it is the
// last place searched on deallocation.
328 global_arena global_arena_;
// Per-thread arenas keyed by thread id. Held by shared_ptr because each owning thread's
// thread_local cleaner also holds a reference.
331 std::map<std::thread::id, std::shared_ptr<arena>> thread_arenas_;
// Per-stream arenas keyed by the raw CUDA stream handle, stored by value.
334 std::map<cudaStream_t, arena> stream_arenas_;
// When true, dump the arena state to logger_ on allocation failure.
336 bool dump_log_on_failure_{};
// File logger. Only created when dump_log_on_failure_ is true.
338 std::shared_ptr<spdlog::logger> logger_{};
// Guards lookup and insertion in thread_arenas_ and stream_arenas_.
340 mutable std::shared_mutex map_mtx_;
// Guards allocation and deallocation. It is held shared on the fast paths and exclusively on
// the fallbacks (allocation retry, cross-arena deallocation).
342 mutable std::shared_mutex mtx_;