Source code for trtutils.core._nvrtc

# Copyright (c) 2024-2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations

import contextlib
import os
import shutil
from functools import lru_cache
from pathlib import Path
from typing import TypeVar

import numpy as np

with contextlib.suppress(Exception):
    from trtutils.compat._libs import cuda, nvrtc

from trtutils._log import LOG

from ._cuda import cuda_call


@lru_cache(maxsize=1)
def find_cuda_include_dir() -> Path | None:
    """
    Find the CUDA include directory for NVRTC compilation.

    Searches in the following order:
    1. CUDA_HOME environment variable
    2. CUDA_PATH environment variable
    3. Path derived from nvcc location
    4. Common default paths (/usr/local/cuda, etc.)

    Returns
    -------
    Path | None
        The path to the CUDA include directory, or None if not found.

    """
    for env_var in ("CUDA_HOME", "CUDA_PATH"):
        cuda_path = os.environ.get(env_var)
        if cuda_path:
            include_path = Path(cuda_path) / "include"
            if include_path.exists():
                LOG.debug(f"Found CUDA include path from {env_var}: {include_path}")
                return include_path

    nvcc_path = shutil.which("nvcc")
    if nvcc_path:
        cuda_root = Path(nvcc_path).parent.parent
        include_path = cuda_root / "include"
        if include_path.exists():
            LOG.debug(f"Found CUDA include path from nvcc: {include_path}")
            return include_path

    common_paths = [
        Path("/usr/local/cuda/include"),
        Path("/usr/local/cuda-13/include"),
        Path("/usr/local/cuda-12/include"),
        Path("/usr/local/cuda-11/include"),
        Path("/opt/cuda/include"),
    ]
    for include_path in common_paths:
        if include_path.exists():
            LOG.debug(f"Found CUDA include path at default location: {include_path}")
            return include_path

    LOG.warning("Could not find CUDA include path for NVRTC compilation")
    return None


def _get_default_nvrtc_opts() -> list[bytes]:
    opts: list[bytes] = []

    include_path = find_cuda_include_dir()
    if include_path:
        opts.append(f"-I{include_path}".encode())

    return opts


def check_nvrtc_err(err: nvrtc.nvrtcResult) -> None:
    """
    Check if a NVRTC error occured and raise an exception if so.

    Parameters
    ----------
    err : nvrtc.nvrtcResult
        The NVRTC return code to check.

    Raises
    ------
    RuntimeError
        If a NVRTC error occured.

    """
    if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
        err_msg = f"NVRTC Error: {err} -> "
        err_msg += f"{nvrtc_call(nvrtc.nvrtcGetErrorString(err))}"
        raise RuntimeError(err_msg)


T = TypeVar("T")


[docs] def nvrtc_call(call: tuple[nvrtc.nvrtcResult, T]) -> T: """ Call a NVRTC function and check for errors. Parameters ---------- call : tuple[cuda.CUresult | cudart.cudaError_t, T] The NVRTC function to call and its arguments. Returns ------- T The result of the NVRTC function call. """ err, res = call[0], call[1:] check_nvrtc_err(err) if len(res) == 1: return res[0] return res
[docs] def compile_kernel( kernel: str, name: str, opts: list[str] | None = None, *, verbose: bool | None = None, ) -> np.char.chararray: """ Compile a CUDA kernel into PTX using NVRTC. Parameters ---------- kernel : str The kernel definition in CUDA. name : str The name of the kernel in the definition. opts : list[str] The optional additional arguments to pass to NVRTC during the compilation of the kernel. verbose : bool, optional Whether or not to output additional information to stdout. If not provided, will default to overall engines verbose setting. Returns ------- tuple[np.char.chararray, str] The compiled PTX kernel and the kernel name. Raises ------ RuntimeError If the version of cuda-python installed does not match the version of CUDA installed. """ kernel_bytes = kernel.encode() kernel_name_bytes = f"{name}.cu".encode() if verbose: LOG.debug(f"Compiling kernel: {name}") # compile the kernel try: prog = nvrtc_call( nvrtc.nvrtcCreateProgram(kernel_bytes, kernel_name_bytes, 0, [], []), ) except RuntimeError as err: if "Failed to dlopen libnvrtc" in str(err): err_msg = str(err) err_msg += ( " Ensure the version of cuda-python installed matches the version of CUDA installed." ) raise RuntimeError(err_msg) from err raise opts_with_cuda_include_dir = _get_default_nvrtc_opts() if opts is not None: for opt in opts: opts_with_cuda_include_dir.append(opt.encode()) nvrtc_call( nvrtc.nvrtcCompileProgram(prog, len(opts_with_cuda_include_dir), opts_with_cuda_include_dir) ) # generate the actual kernel ptx ptx_size = nvrtc_call(nvrtc.nvrtcGetPTXSize(prog)) ptx_buffer = b"\0" * ptx_size nvrtc_call(nvrtc.nvrtcGetPTX(prog, ptx_buffer)) return np.char.array(ptx_buffer)
[docs] def load_kernel( kernel_ptx: np.char.chararray, name: str, *, verbose: bool | None = None, ) -> tuple[cuda.CUmodule, cuda.CUkernel]: """ Load a kernel from a PTX definition. Parameters ---------- kernel_ptx: np.char.chararray The PTX generated by NVRTC, use the compile_kernel function. name: str The name of the kernel inside the PTX definiton. verbose : bool, optional Whether or not to output additional information to stdout. If not provided, will default to overall engines verbose setting. Returns ------- tuple[cuda.CUmodule, cuda.CUkernel] The CUDA module and kernel """ if verbose: LOG.debug(f"Loading kernel: {name} from PTX") module: cuda.CUmodule = cuda_call(cuda.cuModuleLoadData(kernel_ptx.ctypes.data)) kernel: cuda.CUkernel = cuda_call( cuda.cuModuleGetFunction(module, name.encode()), ) return module, kernel
[docs] def compile_and_load_kernel( kernel_code: str, name: str, opts: list[str] | None = None, *, verbose: bool | None = None, ) -> tuple[cuda.CUmodule, cuda.CUkernel]: """ Compile and load a kernel from source definiton. Parameters ---------- kernel_code : str The code definition of the kernel. name : str The name of the kernel. opts : list[str] The optional additional arguments to pass to NVRTC during the compilation of the kernel. verbose : bool, optional Whether or not to output additional information to stdout. If not provided, will default to overall engines verbose setting. Returns ------- tuple[cuda.CUmodule, cuda.CUkernel] The CUDA module and kernel """ ptx = compile_kernel(kernel_code, name, opts, verbose=verbose) module, kernel = load_kernel(ptx, name, verbose=verbose) return module, kernel