21 #include <kvikio/shim/cuda_h_wrapper.hpp>
22 #include <kvikio/shim/utils.hpp>
42 template <
typename Callable>
43 void set(Callable&& c)
45 _callable = std::function(c);
51 void reset() { _callable.reset(); }
64 template <
typename... Args>
67 using T = std::function<CUresult(Args...)>;
68 if (!_callable.has_value()) {
69 throw std::runtime_error(
"No callable has been assigned to the wrapper yet.");
71 return std::any_cast<T&>(_callable)(args...);
77 operator bool()
const {
return _callable.has_value(); }
91 int driver_version{0};
93 decltype(cuInit)* Init{
nullptr};
94 decltype(cuMemHostAlloc)* MemHostAlloc{
nullptr};
95 decltype(cuMemFreeHost)* MemFreeHost{
nullptr};
96 decltype(cuMemcpyHtoDAsync)* MemcpyHtoDAsync{
nullptr};
97 decltype(cuMemcpyDtoHAsync)* MemcpyDtoHAsync{
nullptr};
101 decltype(cuPointerGetAttribute)* PointerGetAttribute{
nullptr};
102 decltype(cuPointerGetAttributes)* PointerGetAttributes{
nullptr};
103 decltype(cuCtxPushCurrent)* CtxPushCurrent{
nullptr};
104 decltype(cuCtxPopCurrent)* CtxPopCurrent{
nullptr};
105 decltype(cuCtxGetCurrent)* CtxGetCurrent{
nullptr};
106 decltype(cuCtxGetDevice)* CtxGetDevice{
nullptr};
107 decltype(cuMemGetAddressRange)* MemGetAddressRange{
nullptr};
108 decltype(cuGetErrorName)* GetErrorName{
nullptr};
109 decltype(cuGetErrorString)* GetErrorString{
nullptr};
110 decltype(cuDeviceGet)* DeviceGet{
nullptr};
111 decltype(cuDeviceGetCount)* DeviceGetCount{
nullptr};
112 decltype(cuDeviceGetAttribute)* DeviceGetAttribute{
nullptr};
113 decltype(cuDevicePrimaryCtxRetain)* DevicePrimaryCtxRetain{
nullptr};
114 decltype(cuDevicePrimaryCtxRelease)* DevicePrimaryCtxRelease{
nullptr};
115 decltype(cuStreamSynchronize)* StreamSynchronize{
nullptr};
116 decltype(cuStreamCreate)* StreamCreate{
nullptr};
117 decltype(cuStreamDestroy)* StreamDestroy{
nullptr};
118 decltype(cuDriverGetVersion)* DriverGetVersion{
nullptr};
125 void operator=(
cudaAPI const&) =
delete;
127 KVIKIO_EXPORT
static cudaAPI& instance();
137 #ifdef KVIKIO_CUDA_FOUND
Shim layer of the CUDA C API.
Non-templated class to hold any callable that returns CUresult.
CUresult operator()(Args... args)
Invoke the contained callable.
void reset()
Destroy the contained callable.
void set(Callable &&c)
Assign a callable to the object.
constexpr bool is_cuda_available()
Check if the CUDA library is available.