dynamic_load_runtime.hpp
1 /*
2  * Copyright (c) 2022, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include <rmm/cuda_device.hpp>
19 
20 #include <cuda_runtime_api.h>
21 
22 #include <dlfcn.h>
23 
24 #include <memory>
25 #include <optional>
26 
27 namespace rmm::detail {
28 
36  static void* get_cuda_runtime_handle()
37  {
38  auto close_cudart = [](void* handle) { ::dlclose(handle); };
39  auto open_cudart = []() {
40  ::dlerror();
41  const int major = CUDART_VERSION / 1000;
42  const std::string libname_ver = "libcudart.so." + std::to_string(major) + ".0";
43  const std::string libname = "libcudart.so";
44 
45  auto ptr = ::dlopen(libname_ver.c_str(), RTLD_LAZY);
46  if (!ptr) { ptr = ::dlopen(libname.c_str(), RTLD_LAZY); }
47  if (ptr) { return ptr; }
48 
49  RMM_FAIL("Unable to dlopen cudart");
50  };
51  static std::unique_ptr<void, decltype(close_cudart)> cudart_handle{open_cudart(), close_cudart};
52  return cudart_handle.get();
53  }
54 
55  template <typename... Args>
56  using function_sig = std::add_pointer_t<cudaError_t(Args...)>;
57 
58  template <typename signature>
59  static std::optional<signature> function(const char* func_name)
60  {
61  auto* runtime = get_cuda_runtime_handle();
62  auto* handle = ::dlsym(runtime, func_name);
63  if (!handle) { return std::nullopt; }
64  auto* function_ptr = reinterpret_cast<signature>(handle);
65  return std::optional<signature>(function_ptr);
66  }
67 };
68 
#if defined(RMM_STATIC_CUDART)
/*
 * Static-cudart variant: defines `name(args...)` as a direct forward to the global `::name`.
 * The static_assert checks at compile time that `::name` converts to `signature`. Because a
 * function address is always non-null, the bool conversion triggers -Waddress, which is
 * silenced locally. `#name` / `#signature` are placed OUTSIDE the string literals so that the
 * preprocessor stringizes them and the message names the actual function and signature.
 */
// clang-format off
#define RMM_CUDART_API_WRAPPER(name, signature)                                          \
  template <typename... Args>                                                            \
  static cudaError_t name(Args... args)                                                  \
  {                                                                                      \
    _Pragma("GCC diagnostic push")                                                       \
    _Pragma("GCC diagnostic ignored \"-Waddress\"")                                      \
    static_assert(static_cast<signature>(::name),                                        \
                  "Failed to find " #name " function with arguments " #signature);       \
    _Pragma("GCC diagnostic pop")                                                        \
    return ::name(args...);                                                              \
  }
// clang-format on
#else
/*
 * Dynamic-cudart variant: defines `name(args...)`, which resolves `name` from libcudart via
 * `dynamic_load_runtime::function` once (cached in a function-local static) and calls it.
 * If the symbol is missing, it fails via RMM_FAIL with a message naming the function.
 */
#define RMM_CUDART_API_WRAPPER(name, signature)                                          \
  template <typename... Args>                                                            \
  static cudaError_t name(Args... args)                                                  \
  {                                                                                      \
    static const auto func = dynamic_load_runtime::function<signature>(#name);           \
    if (func) { return (*func)(args...); }                                               \
    RMM_FAIL("Failed to find " #name " function in libcudart.so");                       \
  }
#endif
93 
#if CUDART_VERSION >= 11020  // the stream-ordered allocator (cudaMallocAsync) arrived in CUDA 11.2

/**
 * @brief Capability checks and call wrappers for the CUDA stream-ordered memory allocator.
 *
 * With RMM_STATIC_CUDART defined, the wrappers forward straight to the runtime. Otherwise each
 * wrapper resolves its entry point from the dynamically loaded libcudart on first use.
 */
struct async_alloc {
  /**
   * @brief Reports whether both the CUDA runtime and the driver support memory pools.
   *
   * Runtime support is a compile-time constant for static cudart. For dynamic cudart it is
   * determined by whether "cudaFreeAsync" can be resolved. Driver support comes from
   * cudaDevAttrMemoryPoolsSupported and counts as supported only if the query succeeds and
   * reports 1. Both answers are computed once and cached in function-local statics.
   *
   * NOTE(review): the driver query uses whichever device is current on the FIRST call; the
   * cached answer is reused regardless of the device current on later calls.
   */
  static bool is_supported()
  {
#if defined(RMM_STATIC_CUDART)
    static bool runtime_ok = (CUDART_VERSION >= 11020);
#else
    static bool runtime_ok =
      dynamic_load_runtime::function<dynamic_load_runtime::function_sig<void*, cudaStream_t>>(
        "cudaFreeAsync")
        .has_value();
#endif

    static bool driver_ok = []() {
      int pools_attr = 0;
      cudaError_t const status = cudaDeviceGetAttribute(
        &pools_attr, cudaDevAttrMemoryPoolsSupported, rmm::detail::current_device().value());
      return (status == cudaSuccess) && (pools_attr == 1);
    }();

    return runtime_ok && driver_ok;
  }

  /**
   * @brief Reports whether the current device's memory pools can export `handle_type`.
   *
   * The supported-handle-types bitmask is queried only when `handle_type` is not
   * cudaMemHandleTypeNone and the toolkit is CUDA 11.3 or newer. Otherwise the mask stays
   * zero, so only a zero-valued handle type is reported as supported.
   *
   * @param handle_type Handle type (bit flag) to test against the device's supported mask.
   * @return false if the query returns cudaErrorInvalidValue (attribute unknown to the
   *         runtime/driver); otherwise true iff every bit of `handle_type` is in the mask.
   * @throws Whatever RMM_CUDA_TRY raises for any other query error.
   */
  static bool is_export_handle_type_supported(cudaMemAllocationHandleType handle_type)
  {
    int supported_mask{0};
#if CUDART_VERSION >= 11030  // cudaDevAttrMemoryPoolSupportedHandleTypes appeared in CUDA 11.3
    if (handle_type != cudaMemHandleTypeNone) {
      cudaError_t const status = cudaDeviceGetAttribute(&supported_mask,
                                                        cudaDevAttrMemoryPoolSupportedHandleTypes,
                                                        rmm::detail::current_device().value());
      // An attribute the runtime/driver does not recognize is reported as "unsupported"
      // instead of being raised as an error.
      if (status == cudaErrorInvalidValue) { return false; }
      RMM_CUDA_TRY(status);  // propagate any other failure
    }
#endif
    return (supported_mask & handle_type) == handle_type;
  }

  /// Shorthand for the runtime function-pointer type taking `Args...`.
  template <typename... Args>
  using cudart_sig = dynamic_load_runtime::function_sig<Args...>;

  // Signatures of the wrapped runtime entry points. RMM_CUDART_API_WRAPPER uses them to
  // type-check (static cudart) or to cast the resolved symbol (dynamic cudart).
  using cudaMemPoolCreate_sig       = cudart_sig<cudaMemPool_t*, const cudaMemPoolProps*>;
  using cudaMemPoolSetAttribute_sig = cudart_sig<cudaMemPool_t, cudaMemPoolAttr, void*>;
  using cudaMemPoolDestroy_sig      = cudart_sig<cudaMemPool_t>;
  using cudaMallocFromPoolAsync_sig = cudart_sig<void**, size_t, cudaMemPool_t, cudaStream_t>;
  using cudaFreeAsync_sig           = cudart_sig<void*, cudaStream_t>;
  using cudaDeviceGetDefaultMemPool_sig = cudart_sig<cudaMemPool_t*, int>;

  RMM_CUDART_API_WRAPPER(cudaMemPoolCreate, cudaMemPoolCreate_sig);
  RMM_CUDART_API_WRAPPER(cudaMemPoolSetAttribute, cudaMemPoolSetAttribute_sig);
  RMM_CUDART_API_WRAPPER(cudaMemPoolDestroy, cudaMemPoolDestroy_sig);
  RMM_CUDART_API_WRAPPER(cudaMallocFromPoolAsync, cudaMallocFromPoolAsync_sig);
  RMM_CUDART_API_WRAPPER(cudaFreeAsync, cudaFreeAsync_sig);
  RMM_CUDART_API_WRAPPER(cudaDeviceGetDefaultMemPool, cudaDeviceGetDefaultMemPool_sig);
};
#endif
177 
178 #undef RMM_CUDART_API_WRAPPER
179 } // namespace rmm::detail
rmm::detail::dynamic_load_runtime
dynamic_load_runtime loads the CUDA runtime library at runtime
Definition: dynamic_load_runtime.hpp:35