cuda.hpp
1 /*
2  * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION.
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 #pragma once
6 
#include <any>
#include <functional>
#include <stdexcept>
#include <utility>

#include <cuda.h>
#include <kvikio/shim/utils.hpp>
13 
14 namespace kvikio {
15 
16 namespace detail {
20 class AnyCallable {
21  private:
22  std::any _callable;
23 
24  public:
31  template <typename Callable>
32  void set(Callable&& c)
33  {
34  _callable = std::function(c);
35  }
36 
40  void reset() { _callable.reset(); }
41 
53  template <typename... Args>
54  CUresult operator()(Args... args)
55  {
56  using T = std::function<CUresult(Args...)>;
57  if (!_callable.has_value()) {
58  throw std::runtime_error("No callable has been assigned to the wrapper yet.");
59  }
60  return std::any_cast<T&>(_callable)(args...);
61  }
62 
66  operator bool() const { return _callable.has_value(); }
67 };
68 
69 } // namespace detail
70 
78 class cudaAPI {
79  public:
80  int driver_version{0};
81 
82  decltype(cuInit)* Init{nullptr};
83  decltype(cuMemHostAlloc)* MemHostAlloc{nullptr};
84  decltype(cuMemFreeHost)* MemFreeHost{nullptr};
85  decltype(cuMemcpyHtoDAsync)* MemcpyHtoDAsync{nullptr};
86  decltype(cuMemcpyDtoHAsync)* MemcpyDtoHAsync{nullptr};
87 
88  detail::AnyCallable MemcpyBatchAsync{};
89 
90  decltype(cuPointerGetAttribute)* PointerGetAttribute{nullptr};
91  decltype(cuPointerGetAttributes)* PointerGetAttributes{nullptr};
92  decltype(cuCtxPushCurrent)* CtxPushCurrent{nullptr};
93  decltype(cuCtxPopCurrent)* CtxPopCurrent{nullptr};
94  decltype(cuCtxGetCurrent)* CtxGetCurrent{nullptr};
95  decltype(cuCtxGetDevice)* CtxGetDevice{nullptr};
96  decltype(cuMemGetAddressRange)* MemGetAddressRange{nullptr};
97  decltype(cuGetErrorName)* GetErrorName{nullptr};
98  decltype(cuGetErrorString)* GetErrorString{nullptr};
99  decltype(cuDeviceGet)* DeviceGet{nullptr};
100  decltype(cuDeviceGetCount)* DeviceGetCount{nullptr};
101  decltype(cuDeviceGetAttribute)* DeviceGetAttribute{nullptr};
102  decltype(cuDevicePrimaryCtxRetain)* DevicePrimaryCtxRetain{nullptr};
103  decltype(cuDevicePrimaryCtxRelease)* DevicePrimaryCtxRelease{nullptr};
104  decltype(cuStreamSynchronize)* StreamSynchronize{nullptr};
105  decltype(cuStreamCreate)* StreamCreate{nullptr};
106  decltype(cuStreamDestroy)* StreamDestroy{nullptr};
107  decltype(cuDriverGetVersion)* DriverGetVersion{nullptr};
108 
109  private:
110  cudaAPI();
111 
112  public:
113  cudaAPI(cudaAPI const&) = delete;
114  void operator=(cudaAPI const&) = delete;
115 
116  KVIKIO_EXPORT static cudaAPI& instance();
117 };
118 
127 
128 } // namespace kvikio
Shim layer of the CUDA C-API.
Definition: cuda.hpp:78
Non-templated class to hold any callable that returns CUresult.
Definition: cuda.hpp:20
CUresult operator()(Args... args)
Invoke the contained callable.
Definition: cuda.hpp:54
void reset()
Destroy the contained callable.
Definition: cuda.hpp:40
void set(Callable &&c)
Assign a callable to the object.
Definition: cuda.hpp:32
KvikIO namespace.
Definition: batch.hpp:16
bool is_cuda_available()
Check if the CUDA library is available.