Source code for trtutils.core._engine

# Copyright (c) 2024-2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

from trtutils._config import CONFIG
from trtutils._flags import FLAGS
from trtutils._log import LOG
from trtutils.compat._libs import trt

from ._device import Device
from ._stream import create_stream

if TYPE_CHECKING:
    from trtutils.compat._libs import cudart_bindings as cudart


def create_engine(
    engine_path: Path | str,
    stream: cudart.cudaStream_t | None = None,
    dla_core: int | None = None,
    device: int | None = None,
    *,
    no_warn: bool | None = None,
) -> tuple[trt.ICudaEngine, trt.IExecutionContext, trt.ILogger, cudart.cudaStream_t]:
    """
    Load a serialized engine from disk.

    Parameters
    ----------
    engine_path : Path | str
        The path to the serialized engine file.
    stream : cudart.cudaStream_t, optional
        When an already made stream is passed, no new stream is created
        and the passed stream is returned unchanged.
        Useful if you want multiple engines to share the same stream.
        Although there is no explicit link between engine and stream,
        the stream returned by this function should be used for execution.
    dla_core : int, optional
        The DLA core to assign DLA layers of the engine to.
        Default is None, in which case the DLA core of the TensorRT
        runtime is not modified and keeps its TensorRT default.
    device : int, optional
        The CUDA device index to create the engine on.
        Default is None, which is forwarded as-is to ``Device``.
    no_warn : bool | None, optional
        If True, the logger output is suppressed while the engine
        is being deserialized.
        Default is None.

    Returns
    -------
    tuple[trt.ICudaEngine, trt.IExecutionContext, trt.ILogger, cudart.cudaStream_t]
        The deserialized engine, the execution context, the logger used
        (the shared trtutils ``LOG`` instance), and the stream
        (the one passed in, or a newly created one if None was passed).

    Raises
    ------
    FileNotFoundError
        If the engine file is not found.
    RuntimeError
        If the TRT runtime could not be created.
        If the engine could not be deserialized.
        If the execution context could not be created.

    """
    # load libnvinfer plugins
    CONFIG.load_plugins()

    engine_path = Path(engine_path)
    if not engine_path.exists():
        err_msg = f"Engine file not found: {engine_path}"
        raise FileNotFoundError(err_msg)

    with Device(device):
        # load the engine from file
        # explicitly a thread-safe operation
        # https://docs.nvidia.com/deeplearning/tensorrt/latest/architecture/how-trt-works.html
        runtime = trt.Runtime(LOG)

        # validate the runtime BEFORE touching any of its attributes, so a
        # failed creation surfaces as the documented RuntimeError rather than
        # an AttributeError from the DLA core assignment below
        if runtime is None:
            err_msg = "Failed to create TRT runtime"
            raise RuntimeError(err_msg)

        if dla_core is not None:
            runtime.DLA_core = dla_core

        # read the serialized engine once; the file handle is closed again
        # before deserialization starts
        serialized_engine = engine_path.read_bytes()

        if no_warn:
            with LOG.suppress():
                engine = runtime.deserialize_cuda_engine(serialized_engine)
        else:
            engine = runtime.deserialize_cuda_engine(serialized_engine)

        # final check on engine
        if engine is None:
            err_msg = f"Failed to deserialize engine from {engine_path}"
            raise RuntimeError(err_msg)

        # create the execution context
        # explicitly a thread-safe operation
        # https://docs.nvidia.com/deeplearning/tensorrt/latest/architecture/how-trt-works.html
        context = engine.create_execution_context()
        if context is None:
            err_msg = "Failed to create execution context"
            raise RuntimeError(err_msg)

        # create a cudart stream only when the caller did not supply one
        if stream is None:
            stream = create_stream()

    return engine, context, LOG, stream
def get_engine_names(
    engine: trt.ICudaEngine,
) -> tuple[list[str], list[str]]:
    """
    Collect the input and output tensor names of a TensorRT engine.

    The names are gathered by walking the engine's tensors by index, so
    each returned list keeps the engine's enumeration order.

    Parameters
    ----------
    engine : trt.ICudaEngine
        The TensorRT engine to query.

    Returns
    -------
    tuple[list[str], list[str]]
        The input names and the output names, each in order of enumeration.

    """
    inputs: list[str] = []
    outputs: list[str] = []

    # the tensor-based API is used when FLAGS.TRT_10 is set,
    # otherwise the binding-based API is used
    count = engine.num_io_tensors if FLAGS.TRT_10 else engine.num_bindings

    for idx in range(count):
        if FLAGS.TRT_10:
            name = engine.get_tensor_name(idx)
            mode = engine.get_tensor_mode(name)
            bucket = inputs if mode == trt.TensorIOMode.INPUT else outputs
        else:
            name = engine.get_binding_name(idx)
            bucket = inputs if engine.binding_is_input(idx) else outputs

        # record the name in whichever list matches its direction
        bucket.append(name)

    return inputs, outputs