Source code for trtutils.core._cuda
# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# ruff: noqa: TRY004
# mypy: disable-error-code="import-untyped"
from __future__ import annotations
from typing import TypeVar
from trtutils._log import LOG
from trtutils.compat._libs import cuda, cudart
def _describe_cuda_err(lookup: tuple, success: object) -> str:
    """
    Turn the result of a CUDA error-lookup call into readable text.

    Parameters
    ----------
    lookup : tuple
        The ``(status, text)`` tuple returned by a lookup binding such as
        ``cuGetErrorName`` or ``cudaGetErrorString``.
    success : object
        The status value that means the lookup itself succeeded.

    Returns
    -------
    str
        The looked-up text as ``str``, or ``"<unknown>"`` if the lookup
        reported a failure.
    """
    status = lookup[0]
    text = lookup[1] if len(lookup) > 1 else None
    if status != success:
        # Never raise from here: a failing lookup must not mask the
        # original error that is being reported.
        return "<unknown>"
    if isinstance(text, bytes):
        # Lookups may return raw bytes; decode so the message shows the text
        # itself instead of a b'...' repr.
        return text.decode("utf-8", errors="replace")
    return str(text)


def check_cuda_err(err: cuda.CUresult | cudart.cudaError_t) -> None:
    """
    Check if a CUDA error occurred and raise an exception if so.

    Parameters
    ----------
    err : cuda.CUresult | cudart.cudaError_t
        The CUDA driver or CUDA runtime status code to check.

    Raises
    ------
    RuntimeError
        If a CUDA or CUDA Runtime error occurred, or if ``err`` is neither a
        ``cuda.CUresult`` nor a ``cudart.cudaError_t``.
    """
    if isinstance(err, cuda.CUresult):
        ok = cuda.CUresult.CUDA_SUCCESS
        if err != ok:
            name = _describe_cuda_err(cuda.cuGetErrorName(err), ok)
            desc = _describe_cuda_err(cuda.cuGetErrorString(err), ok)
            err_msg = f"Cuda Error: {err} -> {name} -> {desc}"
            raise RuntimeError(err_msg)
    elif isinstance(err, cudart.cudaError_t):
        ok = cudart.cudaError_t.cudaSuccess
        if err != ok:
            name = _describe_cuda_err(cudart.cudaGetErrorName(err), ok)
            desc = _describe_cuda_err(cudart.cudaGetErrorString(err), ok)
            err_msg = f"Cuda Runtime Error: {err} -> {name} -> {desc}"
            raise RuntimeError(err_msg)
    else:
        # Kept as RuntimeError (rather than TypeError) so callers that catch
        # RuntimeError keep working; TRY004 is suppressed at file level.
        err_msg = f"Unknown error type: {err}"
        raise RuntimeError(err_msg)
# Generic type variable for the result payload that cuda_call unpacks and returns.
T = TypeVar("T")
[docs]
def cuda_call(call: tuple[cuda.CUresult | cudart.cudaError_t, T]) -> T:
    """
    Unpack the tuple returned by a CUDA binding, raising if it reports an error.

    Parameters
    ----------
    call : tuple[cuda.CUresult | cudart.cudaError_t, T]
        The tuple returned by a ``cuda``/``cudart`` binding: the status code
        first, followed by zero or more result values.

    Returns
    -------
    T
        The single result value when exactly one follows the status code;
        otherwise the remaining values as a slice of ``call`` (an empty tuple
        when the binding returned only a status code).

    Raises
    ------
    RuntimeError
        If the status code signals a failure or is not a recognized CUDA
        status type.
    """
    status = call[0]
    payload = call[1:]
    check_cuda_err(status)
    # Unwrap single-value results so callers get the value, not a 1-tuple.
    return payload[0] if len(payload) == 1 else payload
[docs]
def init_cuda() -> None:
    """
    Initialize the CUDA driver API and log how many devices it reports.

    Raises
    ------
    RuntimeError
        If ``cuInit`` or ``cuDeviceGetCount`` reports an error.
    """
    init_status = cuda.cuInit(0)
    cuda_call(init_status)
    num_devices = cuda_call(cuda.cuDeviceGetCount())
    LOG.info(f"Number of CUDA devices: {num_devices}")