Source code for trtutils.core._engine
# Copyright (c) 2024-2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
from trtutils._config import CONFIG
from trtutils._flags import FLAGS
from trtutils._log import LOG
from trtutils.compat._libs import trt
from ._device import Device
from ._stream import create_stream
if TYPE_CHECKING:
from trtutils.compat._libs import cudart_bindings as cudart
[docs]
def create_engine(
    engine_path: Path | str,
    stream: cudart.cudaStream_t | None = None,
    dla_core: int | None = None,
    device: int | None = None,
    *,
    no_warn: bool | None = None,
) -> tuple[trt.ICudaEngine, trt.IExecutionContext, trt.ILogger, cudart.cudaStream_t]:
    """
    Load a serialized engine from disk.

    Parameters
    ----------
    engine_path : Path | str
        The path to the serialized engine file.
    stream : cudart.cudaStream_t, optional
        When an already made stream is passed, no new stream is created.
        Useful if you want multiple engines to share the same stream.
        Although there is no explicit link between engine and stream, the stream
        returned by this function should be used for execution.
    dla_core : int, optional
        The DLA core to assign to the TensorRT runtime. Default is None,
        in which case the runtime's ``DLA_core`` attribute is not modified.
    device : int, optional
        The CUDA device index to create the engine on. Default is None.
        The value is forwarded unchanged to the ``Device`` context manager.
    no_warn : bool | None, optional
        If True, the engine is deserialized inside ``LOG.suppress()`` to
        silence logger output during deserialization. Default is None.

    Returns
    -------
    tuple[trt.ICudaEngine, trt.IExecutionContext, trt.ILogger, cudart.cudaStream_t]
        The deserialized engine, execution context, logger used, and stream.
        The logger is always the module-level trtutils logger (``LOG``).
        The stream is the one passed in, or a newly created one if None was given.

    Raises
    ------
    FileNotFoundError
        If the engine file is not found.
    RuntimeError
        If the TRT runtime could not be created.
        If the engine could not be deserialized.
        If the execution context could not be created.

    """
    # load libnvinfer plugins
    CONFIG.load_plugins()

    if isinstance(engine_path, str):
        engine_path = Path(engine_path)
    if not engine_path.exists():
        err_msg = f"Engine file not found: {engine_path}"
        raise FileNotFoundError(err_msg)

    with Device(device):
        # load the engine from file
        # explicitly a thread-safe operation
        # https://docs.nvidia.com/deeplearning/tensorrt/latest/architecture/how-trt-works.html
        runtime = trt.Runtime(LOG)

        # validate the runtime BEFORE touching any of its attributes, otherwise
        # a failed runtime would surface as an AttributeError on DLA_core
        # instead of the documented RuntimeError
        if runtime is None:
            err_msg = "Failed to create TRT runtime"
            raise RuntimeError(err_msg)
        if dla_core is not None:
            runtime.DLA_core = dla_core

        # read the serialized engine once; the file handle is closed before
        # deserialization starts
        serialized_engine = engine_path.read_bytes()
        if no_warn:
            with LOG.suppress():
                engine = runtime.deserialize_cuda_engine(serialized_engine)
        else:
            engine = runtime.deserialize_cuda_engine(serialized_engine)

        # final check on engine
        if engine is None:
            err_msg = f"Failed to deserialize engine from {engine_path}"
            raise RuntimeError(err_msg)

        # create the execution context
        # explicitly a thread-safe operation
        # https://docs.nvidia.com/deeplearning/tensorrt/latest/architecture/how-trt-works.html
        context = engine.create_execution_context()
        if context is None:
            err_msg = "Failed to create execution context"
            raise RuntimeError(err_msg)

        # create a cudart stream only when the caller did not supply one
        if stream is None:
            stream = create_stream()

    return engine, context, LOG, stream
[docs]
def get_engine_names(
    engine: trt.ICudaEngine,
) -> tuple[list[str], list[str]]:
    """
    Collect the input and output tensor names of a TensorRT engine.

    Names are gathered by walking the engine's tensors (``FLAGS.TRT_10`` set)
    or bindings (otherwise) by ascending index, so each returned list keeps
    the engine's enumeration order.

    Parameters
    ----------
    engine : trt.ICudaEngine
        The TensorRT engine to query.

    Returns
    -------
    tuple[list[str], list[str]]
        A pair ``(input_names, output_names)``, each in enumeration order.

    """
    inputs: list[str] = []
    outputs: list[str] = []

    if FLAGS.TRT_10:
        # tensor-name based API: the IO mode is looked up by name
        for idx in range(engine.num_io_tensors):
            name = engine.get_tensor_name(idx)
            if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                bucket = inputs
            else:
                bucket = outputs
            bucket.append(name)
    else:
        # binding-index based API: the direction is looked up by index
        for idx in range(engine.num_bindings):
            name = engine.get_binding_name(idx)
            bucket = inputs if engine.binding_is_input(idx) else outputs
            bucket.append(name)

    return inputs, outputs