# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# ruff: noqa: TRY004
# mypy: disable-error-code="import-untyped"
from __future__ import annotations

import contextlib
from typing import TypeVar

# Prefer the ``cuda.bindings`` module layout and fall back to the
# ``from cuda import cuda, cudart`` layout; any failure is suppressed, which
# leaves ``cuda``/``cudart`` undefined rather than failing at import time.
with contextlib.suppress(Exception):
    try:
        import cuda.bindings.driver as cuda
        import cuda.bindings.runtime as cudart
    except (ImportError, ModuleNotFoundError):
        from cuda import cuda, cudart

from trtutils._log import LOG


def _cuda_text(result: tuple[object, ...], success: object) -> str:
    """
    Render the text payload of an error-name / error-string lookup call.

    Parameters
    ----------
    result : tuple[object, ...]
        The tuple returned by the lookup binding: a status code followed by
        the looked-up text.
    success : object
        The status value that signals the lookup itself succeeded.

    Returns
    -------
    str
        The looked-up text as ``str``, or ``"<unavailable>"`` when the lookup
        reported a failure or returned no text.

    """
    status, text = result[0], result[1]
    if status != success or text is None:
        return "<unavailable>"
    if isinstance(text, (bytes, bytearray)):
        # cuda-python returns C strings as bytes (per its API reference);
        # decode so the message does not contain a ``b'...'`` repr.
        return text.decode("utf-8", errors="replace")
    return str(text)


def check_cuda_err(err: cuda.CUresult | cudart.cudaError_t) -> None:
    """
    Check if a CUDA error occurred and raise an exception if so.

    Parameters
    ----------
    err : cuda.CUresult | cudart.cudaError_t
        The CUDA driver or runtime status code to check.

    Raises
    ------
    RuntimeError
        If a CUDA or CUDA Runtime error occurred, or if ``err`` is neither a
        driver nor a runtime status code.

    """
    if isinstance(err, cuda.CUresult):
        success = cuda.CUresult.CUDA_SUCCESS
        if err != success:
            # Query the lookups directly (not via cuda_call) so that a failing
            # lookup cannot replace the original error with its own.
            name = _cuda_text(cuda.cuGetErrorName(err), success)
            desc = _cuda_text(cuda.cuGetErrorString(err), success)
            err_msg = f"Cuda Error: {err} -> {name} -> {desc}"
            raise RuntimeError(err_msg)
    elif isinstance(err, cudart.cudaError_t):
        success = cudart.cudaError_t.cudaSuccess
        if err != success:
            name = _cuda_text(cudart.cudaGetErrorName(err), success)
            desc = _cuda_text(cudart.cudaGetErrorString(err), success)
            err_msg = f"Cuda Runtime Error: {err} -> {name} -> {desc}"
            raise RuntimeError(err_msg)
    else:
        err_msg = f"Unknown error type: {err}"
        raise RuntimeError(err_msg)


# Type variable for the payload carried by CUDA binding result tuples.
T = TypeVar("T")
def cuda_call(call: tuple[cuda.CUresult | cudart.cudaError_t, T]) -> T:
    """
    Unwrap the result tuple of a CUDA binding call, raising on failure.

    The binding's return value carries a status code in position 0 and the
    call's outputs after it. This helper validates the status code and
    strips it off.

    Parameters
    ----------
    call : tuple[cuda.CUresult | cudart.cudaError_t, T]
        The value returned by a CUDA driver/runtime binding, e.g.
        ``cuda_call(cuda.cuDeviceGetCount())``. Element 0 is the status
        code; any remaining elements are the call's outputs.

    Returns
    -------
    T
        The sole output when exactly one follows the status code; otherwise
        the slice holding all outputs (empty when the call produced none).

    Raises
    ------
    RuntimeError
        Propagated from ``check_cuda_err`` when the status code signals an
        error.

    """
    status = call[0]
    outputs = call[1:]
    check_cuda_err(status)
    return outputs[0] if len(outputs) == 1 else outputs
def init_cuda() -> None:
    """
    Initialize the CUDA driver API and log how many devices are visible.

    Raises
    ------
    RuntimeError
        If ``cuInit`` or ``cuDeviceGetCount`` reports a failure.

    """
    # Per the CUDA driver API docs, cuInit must precede other driver calls
    # and its flags argument must be 0.
    cuda_call(cuda.cuInit(0))
    num_devices = cuda_call(cuda.cuDeviceGetCount())
    LOG.info(f"Number of CUDA devices: {num_devices}")