Source code for trtutils._engine

# Copyright (c) 2024-2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations

import contextlib
import time
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import nvtx

from ._flags import FLAGS
from ._log import LOG
from .core._graph import CUDAGraph
from .core._interface import TRTEngineInterface
from .core._memory import (
    memcpy_device_to_host,
    memcpy_device_to_host_async,
    memcpy_host_to_device,
    memcpy_host_to_device_async,
)
from .core._stream import stream_synchronize

if TYPE_CHECKING:
    from typing import ClassVar

    from typing_extensions import Self

    from trtutils.compat._libs import cuda


class TRTEngine(TRTEngineInterface):
    """
    Implements a generic interface for TensorRT engines.

    It is thread and process safe to create multiple TRTEngines.
    It is valid to create a TRTEngine in one thread and use in another.
    Each TRTEngine has its own CUDA context and there are no safeguards
    implemented in the class for datarace conditions. As such, a
    single TRTEngine should not be used in multiple threads or processes.
    """

    _backends: ClassVar[set[str]] = {"auto", "async_v3", "async_v2"}

    def __init__(
        self: Self,
        engine_path: Path | str,
        warmup_iterations: int = 5,
        backend: str = "auto",
        stream: cuda.cudaStream_t | None = None,
        dla_core: int | None = None,
        device: int | None = None,
        *,
        warmup: bool | None = None,
        pagelocked_mem: bool | None = None,
        unified_mem: bool | None = None,
        cuda_graph: bool | None = None,
        no_warn: bool | None = None,
        verbose: bool | None = None,
    ) -> None:
        """
        Load the TensorRT engine from a file.

        Parameters
        ----------
        engine_path : Path | str
            The path to the serialized engine file.
        warmup_iterations : int, optional
            The number of warmup iterations to do, by default 5
        backend : str, optional
            What version of backend execution to use.
            By default 'auto', which will use v3 if available otherwise v2.
            Options are: ['auto', 'async_v3', 'async_v2']
        stream : cuda.cudaStream_t, optional
            The CUDA stream to use for this engine.
            By default None, will allocate a new stream.
        dla_core : int, optional
            The DLA core to assign DLA layers of the engine to. Default is None.
            If None, any DLA layers will be assigned to DLA core 0.
        device : int, optional
            The CUDA device index to use for this engine. Default is None,
            which uses the current device.
        warmup : bool, optional
            Whether to do warmup iterations, by default None
            If None, warmup will be set to False
        pagelocked_mem : bool, optional
            Whether or not to use pagelocked memory for host allocations.
            By default None, which means pagelocked memory will be used.
        unified_mem : bool, optional
            Whether or not the system has unified memory.
            If True, use cudaHostAllocMapped to take advantage of unified memory.
            By default None, which will automatically determine what to use.
        cuda_graph : bool, optional
            Whether to enable CUDA graph capture for optimized execution.
            By default True. Only effective when using async_v3 backend.
        no_warn : bool, optional
            If True, suppresses warnings from TensorRT during engine deserialization.
            Default is None, which means warnings will be shown.
        verbose : bool, optional
            Whether or not to give additional information over stdout.

        Raises
        ------
        ValueError
            If the backend is not valid.

        """
        self._name = Path(engine_path).stem

        if FLAGS.NVTX_ENABLED:
            nvtx.push_range(f"engine::init [{self.name}]")

        # try/finally keeps the NVTX range stack balanced even if engine
        # loading, binding setup or warmup raises
        try:
            super().__init__(
                engine_path,
                stream=stream,
                dla_core=dla_core,
                device=device,
                pagelocked_mem=pagelocked_mem,
                unified_mem=unified_mem,
                no_warn=no_warn,
                verbose=verbose,
            )
            self._nvtx_tags.update(
                {
                    "graph_capture": f"engine::graph_capture [{self.name}]",
                    "execute": f"engine::execute [{self.name}]",
                    "graph_exec": f"engine::graph_exec [{self.name}]",
                    "direct_exec": f"engine::direct_exec [{self.name}]",
                    "raw_exec": f"engine::raw_exec [{self.name}]",
                }
            )

            # solve for execution method
            # only care about v2 or v3 async
            if backend not in TRTEngine._backends:
                err_msg = f"Invalid backend {backend}, options are: {TRTEngine._backends}"
                raise ValueError(err_msg)
            self._async_v3 = FLAGS.EXEC_ASYNC_V3 and backend in ("async_v3", "auto")

            # CUDA graph support
            # needs to happen before input/output bindings are set since
            # CUDA graph is used in those calls
            self._cuda_graph_enabled: bool = (
                cuda_graph if cuda_graph is not None else True
            ) and self._async_v3
            self._cuda_graph: CUDAGraph | None = None
            self._capturing_graph: bool = False  # Guard against capture recursion
            if self._cuda_graph_enabled:
                self._cuda_graph = CUDAGraph(self._stream)

            # if using v3
            # 1.) need to do set_input_shape for all input bindings
            # 2.) need to do set_tensor_address for all input/output bindings
            if self._async_v3:
                self._set_input_bindings()
                self._set_output_bindings()

            # if using the v3 backend also need to track if we are pointing
            # to the 'built-in' tensors, only applies to the inputs
            self._using_engine_tensors: bool = True

            # store timing variable for sleep call before stream_sync
            self._sync_t: float = 0.0

            # store verbose info
            self._verbose = verbose if verbose is not None else False

            if FLAGS.NVTX_ENABLED:
                nvtx.push_range(self._nvtx_tags["warmup"])
            try:
                # None means "no warmup"; store a real bool so later checks
                # (e.g. in _capture_cuda_graph) never see None
                self._warmup: bool = warmup if warmup is not None else False
                if self._warmup:
                    self.warmup(warmup_iterations, verbose=self._verbose)
            finally:
                if FLAGS.NVTX_ENABLED:
                    nvtx.pop_range()  # warmup
        finally:
            if FLAGS.NVTX_ENABLED:
                nvtx.pop_range()  # init

        LOG.debug(f"Created TRTEngine: {self.name}")

    def _set_input_bindings(self: Self) -> None:
        """Point the execution context at the engine-owned input tensors."""
        for i_binding in self._inputs:
            self._context.set_input_shape(i_binding.name, i_binding.shape)
            self._context.set_tensor_address(i_binding.name, i_binding.allocation)

        # CUDA graph is invalid if using new bindings
        if self._cuda_graph and self._cuda_graph.is_captured:
            self._cuda_graph.invalidate()

    def _set_output_bindings(self: Self) -> None:
        """Point the execution context at the engine-owned output tensors."""
        for o_binding in self._outputs:
            self._context.set_tensor_address(o_binding.name, o_binding.allocation)

        # CUDA graph is invalid if using new bindings
        if self._cuda_graph and self._cuda_graph.is_captured:
            self._cuda_graph.invalidate()

    def _capture_cuda_graph(self: Self) -> None:
        """
        Capture a single ``execute_async_v3`` call into the CUDA graph.

        Returns without doing anything if a capture is already in progress.
        If warmup was not requested at construction, one warmup iteration
        is run first so that at least one execution precedes the capture.

        Raises
        ------
        RuntimeError
            If CUDA graphs are not enabled for this engine, if the warmup
            iteration fails, or if the graph is not captured afterwards.
            On warmup/capture failure the graph is dropped (set to None).

        """
        if FLAGS.NVTX_ENABLED:
            nvtx.push_range(self._nvtx_tags["graph_capture"])

        try:
            # Prevent recursion:
            # warmup() -> mock_execute() -> execute() -> _capture_cuda_graph()
            if self._capturing_graph:
                return

            if self._cuda_graph is None:
                err_msg = f"CUDA graph is not enabled in engine: {self._name}"
                raise RuntimeError(err_msg)

            self._capturing_graph = True
            try:
                # at least one execution required prior to graph capture
                # simply use one warmup iteration if warmup didnt get run
                if not self._warmup:
                    try:
                        self.warmup(1, verbose=self._verbose)
                    except RuntimeError as e:
                        # Warmup can fail due to multi-threaded capture conflicts
                        if self._cuda_graph is not None:
                            self._cuda_graph.invalidate()
                        self._cuda_graph = None
                        err_msg = (
                            f"CUDA graph capture failed for engine '{self._name}' during warmup: {e}\n"
                            "This can happen when multiple engines attempt graph capture simultaneously.\n"
                            "To resolve: use cuda_graph=False, or ensure engines are created sequentially, "
                            "or use warmup=True to capture graphs at initialization time."
                        )
                        raise RuntimeError(err_msg) from e

                # CUDAGraph handles capture with a context manager
                with self._cuda_graph:
                    # manually run execute_async_v3 instead of execute since
                    # we only want the TRT engine
                    self._context.execute_async_v3(self._stream)

                # Check if capture succeeded
                if not self._cuda_graph.is_captured:
                    self._cuda_graph = None
                    err_msg = (
                        f"CUDA graph capture failed for engine '{self._name}'.\n"
                        "The engine may not support CUDA graph capture.\n"
                        "To resolve: use cuda_graph=False to disable CUDA graphs for this engine."
                    )
                    raise RuntimeError(err_msg)
            finally:
                # always clear the recursion guard, success or failure
                self._capturing_graph = False
        finally:
            if FLAGS.NVTX_ENABLED:
                nvtx.pop_range()  # graph_capture

    def __del__(self: Self) -> None:
        # _cuda_graph may not exist if __init__ failed early
        with contextlib.suppress(AttributeError):
            if self._cuda_graph is not None:
                self._cuda_graph.invalidate()
        super().__del__()

    def _copy_inputs_from_host(self: Self, data: list[np.ndarray]) -> None:
        """Copy ``data`` into the engine input buffers, one array per input binding."""
        # index into data (rather than zip) so that too few inputs still
        # raises IndexError instead of silently skipping bindings
        if self._pagelocked_mem and self._unified_mem:
            for i_idx, i_binding in enumerate(self._inputs):
                np.copyto(i_binding.host_allocation, data[i_idx])
        elif self._pagelocked_mem:
            for i_idx, i_binding in enumerate(self._inputs):
                memcpy_host_to_device_async(
                    i_binding.allocation,
                    data[i_idx],
                    self._stream,
                )
        else:
            for i_idx, i_binding in enumerate(self._inputs):
                memcpy_host_to_device(
                    i_binding.allocation,
                    data[i_idx],
                )

    def _copy_outputs_to_host(self: Self) -> None:
        """Copy every output binding from its device allocation to its host allocation."""
        if self._unified_mem and self._pagelocked_mem:
            # no explicit copy is issued in this configuration
            return
        if self._pagelocked_mem:
            for o_binding in self._outputs:
                memcpy_device_to_host_async(
                    o_binding.host_allocation,
                    o_binding.allocation,
                    self._stream,
                )
        else:
            for o_binding in self._outputs:
                memcpy_device_to_host(
                    o_binding.host_allocation,
                    o_binding.allocation,
                )

    def _launch_with_pointers(self: Self, pointers: list[int], *, set_pointers: bool) -> None:
        """Enqueue one inference using caller-provided input pointers."""
        if self._async_v3:
            if set_pointers:
                # need to set the input pointers to match the bindings,
                # assume in same order
                for i_idx, pointer in enumerate(pointers):
                    self._context.set_tensor_address(self._inputs[i_idx].name, pointer)
                # set flag to tell future execute calls to reset inputs
                self._using_engine_tensors = False
            self._context.execute_async_v3(self._stream)
        else:
            self._context.execute_async_v2(
                pointers + self._output_allocations,
                self._stream,
            )

    def execute(
        self: Self,
        data: list[np.ndarray],
        *,
        no_copy: bool | None = None,
        verbose: bool | None = None,
        debug: bool | None = None,
    ) -> list[np.ndarray]:
        """
        Execute the network with the given inputs.

        Parameters
        ----------
        data : list[np.ndarray]
            The inputs to the network.
        no_copy : bool, optional
            If True, the outputs will not be copied out
            from the cuda allocated host memory. Instead,
            the host memory will be returned directly.
            This memory WILL BE OVERWRITTEN INPLACE
            by future inferences.
        verbose : bool, optional
            Whether or not to output additional information
            to stdout. If not provided, will default to overall
            engines verbose setting.
        debug : bool, optional
            Enable intermediate stream synchronize for debugging.

        Returns
        -------
        list[np.ndarray]
            The outputs of the network.

        Notes
        -----
        The stream is synchronized before returning, so outputs are ready
        to read on the host. The only exception is while a CUDA graph
        capture is in progress (internal warmup), when the sync is skipped.

        """
        verbose = verbose if verbose is not None else self._verbose

        if verbose:
            LOG.info(f"{time.perf_counter()} {self.name} Dispatch: BEGIN")

        if FLAGS.NVTX_ENABLED:
            nvtx.push_range(self._nvtx_tags["execute"])

        # try/finally keeps the NVTX range balanced if capture/copies raise
        try:
            with self._device_guard:
                # reset the input bindings if direct_exec or raw_exec were used
                if not self._using_engine_tensors:
                    self._set_input_bindings()
                    self._using_engine_tensors = True

                # copy inputs
                self._copy_inputs_from_host(data)

                if debug:
                    stream_synchronize(self._stream)

                # execute
                if self._cuda_graph:
                    if self._cuda_graph.is_captured:
                        # uses already captured graph to handle execution
                        self._cuda_graph.launch()
                    elif not self._capturing_graph:
                        # Capture the graph (warmup inside will use random data)
                        self._capture_cuda_graph()
                        # After capture, re-copy user's input
                        # (warmup overwrote it) and launch
                        if self._cuda_graph is not None and self._cuda_graph.is_captured:
                            self._copy_inputs_from_host(data)
                            self._cuda_graph.launch()
                    else:
                        # Currently capturing graph, use direct execution for warmup
                        self._context.execute_async_v3(self._stream)
                # base execution cases
                elif self._async_v3:
                    self._context.execute_async_v3(self._stream)
                else:
                    self._context.execute_async_v2(self._allocations, self._stream)

                if debug:
                    stream_synchronize(self._stream)

                # copy outputs
                self._copy_outputs_to_host()

                # make sure all operations are complete
                # Skip sync when warming up for graph capture to avoid conflicts
                # with cudaStreamCaptureModeGlobal in multi-threaded scenarios
                if not self._capturing_graph:
                    stream_synchronize(self._stream)

            if verbose:
                LOG.info(f"{time.perf_counter()} {self.name} Dispatch: END")

            # return the results
            if no_copy:
                outputs = [o.host_allocation for o in self._outputs]
            else:
                outputs = [o.host_allocation.copy() for o in self._outputs]
        finally:
            if FLAGS.NVTX_ENABLED:
                nvtx.pop_range()  # execute

        return outputs

    def graph_exec(
        self: Self,
        *,
        debug: bool | None = None,
    ) -> None:
        """
        Launch the captured CUDA graph.

        This method only launches the graph - it does not handle input/output
        memory transfers or graph capture. The graph must already be captured
        (via warmup or prior execute() calls).

        This method does NOT synchronize the stream by default, allowing
        the graph to be embedded in a larger pipeline. Use debug=True to
        force synchronization.

        Parameters
        ----------
        debug : bool, optional
            If True, synchronize the stream after graph launch.
            By default False (no synchronization).

        Raises
        ------
        RuntimeError
            If no CUDA graph has been captured or CUDA graphs are disabled.

        """
        if FLAGS.NVTX_ENABLED:
            nvtx.push_range(self._nvtx_tags["graph_exec"])

        try:
            with self._device_guard:
                if self._cuda_graph is None or not self._cuda_graph.is_captured:
                    err_msg = f"No CUDA graph captured for engine '{self._name}'. "
                    err_msg += "Ensure cuda_graph=True and warmup=True, or call execute() first."
                    raise RuntimeError(err_msg)

                self._cuda_graph.launch()

                if debug:
                    stream_synchronize(self._stream)
        finally:
            if FLAGS.NVTX_ENABLED:
                nvtx.pop_range()  # graph_exec

    def direct_exec(
        self: Self,
        pointers: list[int],
        *,
        set_pointers: bool = True,
        no_warn: bool | None = None,
        verbose: bool | None = None,
        debug: bool | None = None,
    ) -> list[np.ndarray]:
        """
        Execute the network with the given GPU memory pointers.

        The outputs of this function are not copied on return.
        The data will be updated inplace if execute or direct_exec
        is called. Calling this method while giving bad pointers
        will also cause CUDA runtime to crash and program to crash.

        Parameters
        ----------
        pointers : list[int]
            The inputs to the network.
            Pointers must be in the order of expected inputs for the engine.
        set_pointers : bool, optional
            Whether to set tensor addresses before execution.
            If True (default), tensor addresses will be set.
            If False, tensor addresses are assumed to already be configured.
            By default True.
        no_warn : bool, optional
            If True, do not warn about usage.
        verbose : bool, optional
            Whether or not to output additional information
            to stdout. If not provided, will default to overall
            engines verbose setting. This method currently emits
            no verbose output.
        debug : bool, optional
            Enable intermediate stream synchronize for debugging.

        Returns
        -------
        list[np.ndarray]
            The outputs of the network.

        Notes
        -----
        This method always synchronizes the stream before returning,
        ensuring outputs are ready to read on the host.

        """
        # resolved for parity with execute(); not read further in this method
        verbose = verbose if verbose is not None else self._verbose

        if not no_warn:
            LOG.warning(
                "Calling direct_exec is potentially dangerous, "
                "ensure all pointers and data are valid. "
                "Outputs can be overwritten inplace!",
            )

        if FLAGS.NVTX_ENABLED:
            nvtx.push_range(self._nvtx_tags["direct_exec"])

        try:
            with self._device_guard:
                # execute
                self._launch_with_pointers(pointers, set_pointers=set_pointers)

                if debug:
                    stream_synchronize(self._stream)

                # copy outputs
                self._copy_outputs_to_host()

                # make sure all operations are complete
                stream_synchronize(self._stream)
        finally:
            if FLAGS.NVTX_ENABLED:
                nvtx.pop_range()  # direct_exec

        # return the output host allocations
        return self._output_host_allocations

    def raw_exec(
        self: Self,
        pointers: list[int],
        *,
        set_pointers: bool = True,
        no_warn: bool | None = None,
        verbose: bool | None = None,
        debug: bool | None = None,
    ) -> list[int]:
        """
        Execute the network with the given GPU memory pointers.

        The outputs of this function are the direct GPU pointers
        of the output allocations.

        Parameters
        ----------
        pointers : list[int]
            The inputs to the network.
            Pointers must be in the order of expected inputs for the engine.
        set_pointers : bool, optional
            Whether to set tensor addresses before execution.
            If True (default), tensor addresses will be set.
            If False, tensor addresses are assumed to already be configured.
            By default True.
        no_warn : bool, optional
            If True, do not warn about usage.
        verbose : bool, optional
            Whether or not to output additional information
            to stdout. If not provided, will default to overall
            engines verbose setting. This method currently emits
            no verbose output.
        debug : bool, optional
            Enable intermediate stream synchronize for debugging.

        Returns
        -------
        list[int]
            The pointers to the network outputs.

        Notes
        -----
        This method does NOT synchronize the stream by default. The caller
        is responsible for synchronization if needed. Use debug=True to
        force synchronization after execution.

        """
        # resolved for parity with execute(); not read further in this method
        verbose = verbose if verbose is not None else self._verbose

        if not no_warn:
            LOG.warning(
                "Calling raw_exec is potentially dangerous, "
                "ensure all pointers and data are valid. "
                "Outputs can be overwritten inplace!",
            )

        if FLAGS.NVTX_ENABLED:
            nvtx.push_range(self._nvtx_tags["raw_exec"])

        try:
            with self._device_guard:
                # execute
                self._launch_with_pointers(pointers, set_pointers=set_pointers)

                if debug:
                    stream_synchronize(self._stream)
        finally:
            if FLAGS.NVTX_ENABLED:
                nvtx.pop_range()  # raw_exec

        # return the pointers to the output allocations
        return self._output_allocations