cuda.hpp
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
#include <any>
#include <functional>
#include <stdexcept>
#include <utility>

#include <cuda.h>
#include <kvikio/shim/utils.hpp>
13 
14 namespace kvikio {
15 
16 namespace detail {
20 class AnyCallable {
21  private:
22  std::any _callable;
23 
24  public:
31  template <typename Callable>
32  void set(Callable&& c)
33  {
34  _callable = std::function(c);
35  }
36 
40  void reset() { _callable.reset(); }
41 
53  template <typename... Args>
54  CUresult operator()(Args... args)
55  {
56  using T = std::function<CUresult(Args...)>;
57  if (!_callable.has_value()) {
58  throw std::runtime_error("No callable has been assigned to the wrapper yet.");
59  }
60  return std::any_cast<T&>(_callable)(args...);
61  }
62 
66  operator bool() const { return _callable.has_value(); }
67 };
68 
69 } // namespace detail
70 
78 class cudaAPI {
79  public:
80  int driver_version{0};
81 
82  decltype(cuInit)* Init{nullptr};
83  decltype(cuMemHostAlloc)* MemHostAlloc{nullptr};
84  decltype(cuMemFreeHost)* MemFreeHost{nullptr};
85  decltype(cuMemHostRegister)* MemHostRegister{nullptr};
86  decltype(cuMemHostUnregister)* MemHostUnregister{nullptr};
87  decltype(cuMemcpyHtoDAsync)* MemcpyHtoDAsync{nullptr};
88  decltype(cuMemcpyDtoHAsync)* MemcpyDtoHAsync{nullptr};
89 
90  detail::AnyCallable MemcpyBatchAsync{};
91 
92  decltype(cuPointerGetAttribute)* PointerGetAttribute{nullptr};
93  decltype(cuPointerGetAttributes)* PointerGetAttributes{nullptr};
94  decltype(cuCtxPushCurrent)* CtxPushCurrent{nullptr};
95  decltype(cuCtxPopCurrent)* CtxPopCurrent{nullptr};
96  decltype(cuCtxGetCurrent)* CtxGetCurrent{nullptr};
97  decltype(cuCtxGetDevice)* CtxGetDevice{nullptr};
98  decltype(cuMemGetAddressRange)* MemGetAddressRange{nullptr};
99  decltype(cuGetErrorName)* GetErrorName{nullptr};
100  decltype(cuGetErrorString)* GetErrorString{nullptr};
101  decltype(cuDeviceGet)* DeviceGet{nullptr};
102  decltype(cuDeviceGetCount)* DeviceGetCount{nullptr};
103  decltype(cuDeviceGetAttribute)* DeviceGetAttribute{nullptr};
104  decltype(cuDevicePrimaryCtxRetain)* DevicePrimaryCtxRetain{nullptr};
105  decltype(cuDevicePrimaryCtxRelease)* DevicePrimaryCtxRelease{nullptr};
106  decltype(cuStreamSynchronize)* StreamSynchronize{nullptr};
107  decltype(cuStreamCreate)* StreamCreate{nullptr};
108  decltype(cuStreamDestroy)* StreamDestroy{nullptr};
109  decltype(cuDriverGetVersion)* DriverGetVersion{nullptr};
110 
111  private:
112  cudaAPI();
113 
114  public:
115  cudaAPI(cudaAPI const&) = delete;
116  void operator=(cudaAPI const&) = delete;
117 
118  KVIKIO_EXPORT static cudaAPI& instance();
119 };
120 
129 
130 } // namespace kvikio
Shim layer of the cuda C-API.
Definition: cuda.hpp:78
Non-templated class to hold any callable that returns CUresult.
Definition: cuda.hpp:20
CUresult operator()(Args... args)
Invoke the contained callable.
Definition: cuda.hpp:54
void reset()
Destroy the contained callable.
Definition: cuda.hpp:40
void set(Callable &&c)
Assign a callable to the object.
Definition: cuda.hpp:32
KvikIO namespace.
Definition: batch.hpp:16
bool is_cuda_available()
Check if the CUDA library is available.