cuda.hpp
1 /*
2  * Copyright (c) 2022-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
#include <any>
#include <functional>
#include <stdexcept>
#include <utility>

#include <kvikio/shim/cuda_h_wrapper.hpp>
#include <kvikio/shim/utils.hpp>
24 
25 namespace kvikio {
26 
27 namespace detail {
31 class AnyCallable {
32  private:
33  std::any _callable;
34 
35  public:
42  template <typename Callable>
43  void set(Callable&& c)
44  {
45  _callable = std::function(c);
46  }
47 
51  void reset() { _callable.reset(); }
52 
64  template <typename... Args>
65  CUresult operator()(Args... args)
66  {
67  using T = std::function<CUresult(Args...)>;
68  if (!_callable.has_value()) {
69  throw std::runtime_error("No callable has been assigned to the wrapper yet.");
70  }
71  return std::any_cast<T&>(_callable)(args...);
72  }
73 
77  operator bool() const { return _callable.has_value(); }
78 };
79 
80 } // namespace detail
81 
89 class cudaAPI {
90  public:
91  int driver_version{0};
92 
93  decltype(cuInit)* Init{nullptr};
94  decltype(cuMemHostAlloc)* MemHostAlloc{nullptr};
95  decltype(cuMemFreeHost)* MemFreeHost{nullptr};
96  decltype(cuMemcpyHtoDAsync)* MemcpyHtoDAsync{nullptr};
97  decltype(cuMemcpyDtoHAsync)* MemcpyDtoHAsync{nullptr};
98 
99  detail::AnyCallable MemcpyBatchAsync{};
100 
101  decltype(cuPointerGetAttribute)* PointerGetAttribute{nullptr};
102  decltype(cuPointerGetAttributes)* PointerGetAttributes{nullptr};
103  decltype(cuCtxPushCurrent)* CtxPushCurrent{nullptr};
104  decltype(cuCtxPopCurrent)* CtxPopCurrent{nullptr};
105  decltype(cuCtxGetCurrent)* CtxGetCurrent{nullptr};
106  decltype(cuCtxGetDevice)* CtxGetDevice{nullptr};
107  decltype(cuMemGetAddressRange)* MemGetAddressRange{nullptr};
108  decltype(cuGetErrorName)* GetErrorName{nullptr};
109  decltype(cuGetErrorString)* GetErrorString{nullptr};
110  decltype(cuDeviceGet)* DeviceGet{nullptr};
111  decltype(cuDeviceGetCount)* DeviceGetCount{nullptr};
112  decltype(cuDeviceGetAttribute)* DeviceGetAttribute{nullptr};
113  decltype(cuDevicePrimaryCtxRetain)* DevicePrimaryCtxRetain{nullptr};
114  decltype(cuDevicePrimaryCtxRelease)* DevicePrimaryCtxRelease{nullptr};
115  decltype(cuStreamSynchronize)* StreamSynchronize{nullptr};
116  decltype(cuStreamCreate)* StreamCreate{nullptr};
117  decltype(cuStreamDestroy)* StreamDestroy{nullptr};
118  decltype(cuDriverGetVersion)* DriverGetVersion{nullptr};
119 
120  private:
121  cudaAPI();
122 
123  public:
124  cudaAPI(cudaAPI const&) = delete;
125  void operator=(cudaAPI const&) = delete;
126 
127  KVIKIO_EXPORT static cudaAPI& instance();
128 };
129 
/**
 * @brief Check if the CUDA library is available.
 *
 * When KVIKIO_CUDA_FOUND is not defined the answer is known at compile time and is always false.
 * When it is defined, the function is only declared here and its definition lives elsewhere.
 *
 * @return true if the CUDA library is available, false otherwise.
 */
#if !defined(KVIKIO_CUDA_FOUND)
constexpr bool is_cuda_available() { return false; }
#else
bool is_cuda_available();
#endif
142 
143 } // namespace kvikio
Shim layer of the cuda C-API.
Definition: cuda.hpp:89
Non-templated class to hold any callable that returns CUresult.
Definition: cuda.hpp:31
CUresult operator()(Args... args)
Invoke the contained callable.
Definition: cuda.hpp:65
void reset()
Destroy the contained callable.
Definition: cuda.hpp:51
void set(Callable &&c)
Assign a callable to the object.
Definition: cuda.hpp:43
KvikIO namespace.
Definition: batch.hpp:27
constexpr bool is_cuda_available()
Check if the CUDA library is available.
Definition: cuda.hpp:140