20 #include <cuda_runtime_api.h>
23 #include <type_traits>
/**
 * @brief Constructs the exception from a message and the CUDA error code that caused it.
 *
 * `message` is forwarded to the std::runtime_error base (so it is what `what()` returns) and
 * `error` is stored in the `_cudaError` member.
 *
 * @param message Error message
 * @param error CUDA error code
 *
 * NOTE(review): this is a fragment of an extracted listing. The enclosing class declaration
 * and the constructor body are not visible here. The leading "67 " / "68 " tokens appear to be
 * line-number residue rather than code. Confirm against the real header.
 */
67 cuda_error(std::string
const& message, cudaError_t
const& error)
68 : std::runtime_error(message),
_cudaError(error)
// Two-level stringification.
// STRINGIFY_DETAIL applies `#` directly, so its argument is NOT macro-expanded first.
// Routing through CUDF_STRINGIFY forces the argument to be fully expanded before it is
// stringified. For example, CUDF_STRINGIFY(__LINE__) yields the current line number as a
// string literal rather than "__LINE__".
#define STRINGIFY_DETAIL(x) #x
#define CUDF_STRINGIFY(x) STRINGIFY_DETAIL(x)
/**
 * @brief Checks a (pre-)condition and throws an exception when it is violated.
 *
 * Usage:
 *   CUDF_EXPECTS(condition, "reason");                         // throws cudf::logic_error
 *   CUDF_EXPECTS(condition, "reason", std::invalid_argument);  // throws the given type
 *
 * `reason` must be a string literal. It is concatenated at compile time after the
 * "CUDF failure at: <file>:<line>: " prefix.
 *
 * The exception type must derive from std::exception (enforced by static_assert) and must be
 * brace-constructible from a string literal. The condition is evaluated exactly once.
 */
#define CUDF_EXPECTS(...)                                             \
  GET_CUDF_EXPECTS_MACRO(__VA_ARGS__, CUDF_EXPECTS_3, CUDF_EXPECTS_2) \
  (__VA_ARGS__)

// Selects CUDF_EXPECTS_3 or CUDF_EXPECTS_2 by argument count. The caller's arguments push the
// two candidate names to the right, so NAME lands on the one matching the count.
#define GET_CUDF_EXPECTS_MACRO(_1, _2, _3, NAME, ...) NAME

// Three-argument form. The do/while(0) wrapper makes the expansion a single statement, so it
// is safe as the body of an unbraced if/else, and it gives the static_assert its own scope.
#define CUDF_EXPECTS_3(_condition, _reason, _exception_type)                   \
  do {                                                                         \
    static_assert(std::is_base_of_v<std::exception, _exception_type>);         \
    (_condition)                                                               \
      ? static_cast<void>(0)                                                   \
      : throw _exception_type{                                                 \
          "CUDF failure at: " __FILE__ ":" CUDF_STRINGIFY(__LINE__) ": " _reason}; \
  } while (0)

// Two-argument form: defaults the exception type to cudf::logic_error.
#define CUDF_EXPECTS_2(_condition, _reason) CUDF_EXPECTS_3(_condition, _reason, cudf::logic_error)
/**
 * @brief Unconditionally throws an exception whose message starts with
 * "CUDF failure at: <file>:<line>: ".
 *
 * Usage:
 *   CUDF_FAIL("message");                     // throws cudf::logic_error
 *   CUDF_FAIL("message", std::out_of_range);  // throws the given type
 *
 * `what` must be a string literal.
 *
 * The expansion is a throw-expression rather than a statement, so it can also be used as an
 * operand of the conditional operator.
 */
#define CUDF_FAIL(...)                                       \
  GET_CUDF_FAIL_MACRO(__VA_ARGS__, CUDF_FAIL_2, CUDF_FAIL_1) \
  (__VA_ARGS__)

// Selects CUDF_FAIL_2 or CUDF_FAIL_1 by argument count.
#define GET_CUDF_FAIL_MACRO(_1, _2, NAME, ...) NAME

// Two-argument form. The prefix is "CUDF failure at: " (with a trailing space), the same
// prefix CUDF_EXPECTS uses, so both macros produce uniformly formatted messages.
#define CUDF_FAIL_2(_what, _exception_type) \
  throw _exception_type { "CUDF failure at: " __FILE__ ":" CUDF_STRINGIFY(__LINE__) ": " _what }

// One-argument form: defaults the exception type to cudf::logic_error.
#define CUDF_FAIL_1(_what) CUDF_FAIL_2(_what, cudf::logic_error)
181 inline void throw_cuda_error(cudaError_t error,
const char* file,
unsigned int line)
186 auto const last = cudaFree(0);
187 auto const msg = std::string{
"CUDA error encountered at: " + std::string{file} +
":" +
188 std::to_string(line) +
": " + std::to_string(error) +
" " +
189 cudaGetErrorName(error) +
" " + cudaGetErrorString(error)};
192 if (error == last && last == cudaDeviceSynchronize()) {
193 throw fatal_cuda_error{
"Fatal " + msg, error};
195 throw cuda_error{msg, error};
/**
 * @brief Error-checking wrapper for a CUDA runtime API call.
 *
 * Evaluates `call` exactly once. If the result is not `cudaSuccess`, the error code and the
 * call site (`__FILE__`, `__LINE__`) are forwarded to `cudf::detail::throw_cuda_error`, which
 * throws.
 *
 * The do/while(0) wrapper has two purposes:
 *   - It scopes `status`, so the macro can be used several times in one scope.
 *   - It makes the expansion a single statement that is safe inside an unbraced if/else.
 *
 * Invoke with a trailing semicolon: `CUDF_CUDA_TRY(cudaMemcpyAsync(...));`
 */
#define CUDF_CUDA_TRY(call)                                                                     \
  do {                                                                                          \
    cudaError_t const status = (call);                                                          \
    if (cudaSuccess != status) { cudf::detail::throw_cuda_error(status, __FILE__, __LINE__); } \
  } while (0)
/**
 * @brief Checks for CUDA errors, optionally synchronizing `stream` first.
 *
 * Debug builds first synchronize `stream`, so errors from work already enqueued on it surface
 * here, and then peek at the last runtime error.
 *
 * Release builds only peek at the last runtime error, which avoids the synchronization cost.
 * `stream` is not evaluated in that variant.
 *
 * cudaPeekAtLastError() does not reset the error state.
 *
 * Both variants expand to a single statement and expect a trailing semicolon at the call site.
 *
 * NOTE(review): the conditional that selects between the two definitions was not present in
 * the extracted listing. NDEBUG, the standard debug/release switch, is assumed here; confirm
 * against the build configuration.
 */
#ifndef NDEBUG
#define CUDF_CHECK_CUDA(stream)                   \
  do {                                            \
    CUDF_CUDA_TRY(cudaStreamSynchronize(stream)); \
    CUDF_CUDA_TRY(cudaPeekAtLastError());         \
  } while (0)
#else
#define CUDF_CHECK_CUDA(stream) CUDF_CUDA_TRY(cudaPeekAtLastError())
#endif