# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations
import contextlib
from typing import TypeVar
import numpy as np
with contextlib.suppress(Exception):
try:
import cuda.bindings.driver as cuda
import cuda.bindings.nvrtc as nvrtc
except (ImportError, ModuleNotFoundError):
from cuda import cuda, nvrtc
from trtutils._log import LOG
from ._cuda import cuda_call
from ._lock import MEM_ALLOC_LOCK, NVRTC_LOCK
def check_nvrtc_err(err: nvrtc.nvrtcResult) -> None:
    """
    Raise an exception when an NVRTC return code signals a failure.

    Parameters
    ----------
    err : nvrtc.nvrtcResult
        The status code returned by an NVRTC call.

    Raises
    ------
    RuntimeError
        If the status code is anything other than ``NVRTC_SUCCESS``.
        The message contains the code and the description NVRTC
        reports for it.

    """
    # nothing to do on the success path
    if err == nvrtc.nvrtcResult.NVRTC_SUCCESS:
        return

    prefix = f"NVRTC Error: {err} -> "
    # ask NVRTC itself for the readable description of the code
    description = nvrtc_call(nvrtc.nvrtcGetErrorString(err))
    raise RuntimeError(f"{prefix}{description}")
# Generic type of the value that follows the status code in the tuple
# handed to nvrtc_call.
T = TypeVar("T")
[docs]
def nvrtc_call(call: tuple[nvrtc.nvrtcResult, T]) -> T:
    """
    Validate the result of a NVRTC call and unwrap its payload.

    Parameters
    ----------
    call : tuple[nvrtc.nvrtcResult, T]
        The tuple produced by a NVRTC binding call: the status code
        first, followed by any values the call produced.

    Returns
    -------
    T
        The single value following the status code when there is
        exactly one; otherwise the tuple of all values following the
        status code (empty when the call produced none).

    Raises
    ------
    RuntimeError
        Raised by check_nvrtc_err if the status code is not a success.

    """
    status = call[0]
    payload = call[1:]
    check_nvrtc_err(status)

    # zero or several outputs are handed back as a tuple
    if len(payload) != 1:
        return payload
    return payload[0]
[docs]
def _read_program_log(prog: nvrtc.nvrtcProgram) -> str:
    """Fetch the NVRTC program log on a best-effort basis, empty string if unavailable."""
    try:
        log_size = nvrtc_call(nvrtc.nvrtcGetProgramLogSize(prog))
        log_buffer = b" " * log_size
        nvrtc_call(nvrtc.nvrtcGetProgramLog(prog, log_buffer))
    except RuntimeError:
        # deliberate best-effort: the caller still raises the original
        # compilation error, just without the extra log text
        return ""
    return log_buffer.decode(errors="replace").rstrip("\0").strip()


def compile_kernel(
    kernel: str,
    name: str,
    opts: list[str] | None = None,
    *,
    verbose: bool | None = None,
) -> np.char.chararray:
    """
    Compile a CUDA kernel into PTX using NVRTC.

    Parameters
    ----------
    kernel : str
        The kernel definition in CUDA.
    name : str
        The name of the kernel in the definition. The NVRTC program
        is labelled ``<name>.cu``.
    opts : list[str], optional
        The optional additional arguments to pass to NVRTC during
        the compilation of the kernel. ``str`` entries are encoded to
        ``bytes`` before being handed to NVRTC, ``bytes`` entries are
        passed through unchanged. By default None, which compiles
        without additional options.
    verbose : bool, optional
        Whether or not to log additional information.
        When None or False, no extra message is logged by this function.

    Returns
    -------
    np.char.chararray
        The compiled PTX of the kernel.

    Raises
    ------
    RuntimeError
        If the version of cuda-python installed does not match the version
        of CUDA installed, or if any NVRTC call fails. When compilation
        fails, the NVRTC program log is appended to the message if it
        could be retrieved.

    """
    kernel_bytes = kernel.encode()
    kernel_name_bytes = f"{name}.cu".encode()

    if verbose:
        LOG.debug(f"Compiling kernel: {name}")

    # The NVRTC binding takes its options as bytes; encode str entries so
    # the documented list[str] form works. Done before the program is
    # created so a bad option cannot leave a program handle behind.
    raw_opts = [] if opts is None else opts
    encoded_opts = [
        opt.encode() if isinstance(opt, str) else opt for opt in raw_opts
    ]

    # create the program
    try:
        with MEM_ALLOC_LOCK, NVRTC_LOCK:
            prog = nvrtc_call(
                nvrtc.nvrtcCreateProgram(kernel_bytes, kernel_name_bytes, 0, [], []),
            )
    except RuntimeError as err:
        if "Failed to dlopen libnvrtc" in str(err):
            err_msg = str(err)
            err_msg += " Ensure the version of cuda-python installed matches the version of CUDA installed."
            raise RuntimeError(err_msg) from err
        raise

    with MEM_ALLOC_LOCK, NVRTC_LOCK:
        try:
            # compile the kernel
            try:
                nvrtc_call(
                    nvrtc.nvrtcCompileProgram(prog, len(encoded_opts), encoded_opts),
                )
            except RuntimeError as err:
                # the status code alone rarely says what is wrong with the
                # source; attach the compiler output when it is available
                log = _read_program_log(prog)
                if not log:
                    raise
                err_msg = f"{err}\nNVRTC program log:\n{log}"
                raise RuntimeError(err_msg) from err

            # generate the actual kernel ptx
            ptx_size = nvrtc_call(nvrtc.nvrtcGetPTXSize(prog))
            ptx_buffer = b"\0" * ptx_size
            nvrtc_call(nvrtc.nvrtcGetPTX(prog, ptx_buffer))
        finally:
            # Release the program handle on every path. The return code is
            # intentionally not checked so a cleanup failure can never mask
            # the compilation error being propagated.
            nvrtc.nvrtcDestroyProgram(prog)

    return np.char.array(ptx_buffer)
[docs]
def load_kernel(
    kernel_ptx: np.char.chararray,
    name: str,
    *,
    verbose: bool | None = None,
) -> tuple[cuda.CUmodule, cuda.CUkernel]:
    """
    Load a CUDA module from PTX and look up a kernel inside it.

    Parameters
    ----------
    kernel_ptx : np.char.chararray
        The PTX generated by NVRTC, as produced by compile_kernel.
        The address of its underlying buffer is handed to the driver.
    name : str
        The name of the kernel inside the PTX definition.
    verbose : bool, optional
        Whether or not to log additional information.
        When None or False, no extra message is logged by this function.

    Returns
    -------
    tuple[cuda.CUmodule, cuda.CUkernel]
        The loaded CUDA module and the handle of the requested kernel.
        NOTE(review): the handle comes from cuModuleGetFunction, which
        the CUDA driver API documents as yielding a CUfunction - confirm
        whether the CUkernel annotation is intended.

    """
    if verbose:
        LOG.debug(f"Loading kernel: {name} from PTX")

    # the driver consumes the raw address of the PTX buffer
    ptx_address = kernel_ptx.ctypes.data
    loaded_module: cuda.CUmodule = cuda_call(cuda.cuModuleLoadData(ptx_address))

    # the driver looks the entry point up by its encoded name
    entry_name = name.encode()
    kernel_handle: cuda.CUkernel = cuda_call(
        cuda.cuModuleGetFunction(loaded_module, entry_name),
    )

    return loaded_module, kernel_handle
[docs]
def compile_and_load_kernel(
    kernel_code: str,
    name: str,
    opts: list[str] | None = None,
    *,
    verbose: bool | None = None,
) -> tuple[cuda.CUmodule, cuda.CUkernel]:
    """
    Compile a kernel from its source definition and load it.

    Convenience wrapper which feeds the output of compile_kernel
    directly into load_kernel.

    Parameters
    ----------
    kernel_code : str
        The code definition of the kernel.
    name : str
        The name of the kernel.
    opts : list[str], optional
        The optional additional arguments to pass to NVRTC during
        the compilation of the kernel.
    verbose : bool, optional
        Whether or not to log additional information.
        Forwarded unchanged to both compile_kernel and load_kernel.

    Returns
    -------
    tuple[cuda.CUmodule, cuda.CUkernel]
        The CUDA module and kernel

    """
    return load_kernel(
        compile_kernel(kernel_code, name, opts, verbose=verbose),
        name,
        verbose=verbose,
    )