All Classes Namespaces Functions Enumerations Enumerator Modules Pages
cuda.hpp
1 /*
2  * Copyright (c) 2022-2025, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include <kvikio/shim/cuda_h_wrapper.hpp>
19 #include <kvikio/shim/utils.hpp>
20 
21 namespace kvikio {
22 
30 class cudaAPI {
31  public:
32  decltype(cuInit)* Init{nullptr};
33  decltype(cuMemHostAlloc)* MemHostAlloc{nullptr};
34  decltype(cuMemFreeHost)* MemFreeHost{nullptr};
35  decltype(cuMemcpyHtoDAsync)* MemcpyHtoDAsync{nullptr};
36  decltype(cuMemcpyDtoHAsync)* MemcpyDtoHAsync{nullptr};
37  decltype(cuPointerGetAttribute)* PointerGetAttribute{nullptr};
38  decltype(cuPointerGetAttributes)* PointerGetAttributes{nullptr};
39  decltype(cuCtxPushCurrent)* CtxPushCurrent{nullptr};
40  decltype(cuCtxPopCurrent)* CtxPopCurrent{nullptr};
41  decltype(cuCtxGetCurrent)* CtxGetCurrent{nullptr};
42  decltype(cuMemGetAddressRange)* MemGetAddressRange{nullptr};
43  decltype(cuGetErrorName)* GetErrorName{nullptr};
44  decltype(cuGetErrorString)* GetErrorString{nullptr};
45  decltype(cuDeviceGet)* DeviceGet{nullptr};
46  decltype(cuDevicePrimaryCtxRetain)* DevicePrimaryCtxRetain{nullptr};
47  decltype(cuDevicePrimaryCtxRelease)* DevicePrimaryCtxRelease{nullptr};
48  decltype(cuStreamSynchronize)* StreamSynchronize{nullptr};
49  decltype(cuStreamCreate)* StreamCreate{nullptr};
50  decltype(cuStreamDestroy)* StreamDestroy{nullptr};
51 
52  private:
53  cudaAPI();
54 
55  public:
56  cudaAPI(cudaAPI const&) = delete;
57  void operator=(cudaAPI const&) = delete;
58 
59  KVIKIO_EXPORT static cudaAPI& instance();
60 };
61 
/**
 * @brief Check if the CUDA library is available.
 *
 * Which definition is visible depends on how KvikIO was built:
 *  - `KVIKIO_CUDA_FOUND` not defined: a `constexpr` function that always yields `false`, usable
 *    in constant expressions.
 *  - `KVIKIO_CUDA_FOUND` defined: a plain (non-constexpr) declaration whose definition lives
 *    outside this header.
 *
 * @return `true` if the CUDA library is available, `false` otherwise.
 */
#if !defined(KVIKIO_CUDA_FOUND)
constexpr bool is_cuda_available() { return false; }
#else
bool is_cuda_available();
#endif
74 
75 } // namespace kvikio
Shim layer of the CUDA C-API.
Definition: cuda.hpp:30
KvikIO namespace.
Definition: batch.hpp:27
constexpr bool is_cuda_available()
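Check if the CUDA library is available. This `constexpr` variant is used when KvikIO is built without CUDA support (`KVIKIO_CUDA_FOUND` not defined) and always returns `false`.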
Definition: cuda.hpp:72