# Copyright (c) 2024-2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations
import shutil
from pathlib import Path
from typing import TYPE_CHECKING
from trtutils._config import CONFIG
from trtutils._flags import FLAGS
from trtutils._log import LOG
from trtutils.compat._libs import trt
from trtutils.core import cache as caching_tools
from trtutils.core._device import Device
from trtutils.core.cache import query_timing_cache, save_timing_cache_to_global
from ._calibrator import EngineCalibrator
from ._onnx import read_onnx
from ._utils import get_check_dla
ProgressBar = None
if FLAGS.BUILD_PROGRESS:
from ._progress import ProgressBar
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ._batcher import AbstractBatcher
# Inclusive lower/upper bounds that build_engine accepts for its
# optimization_level argument; values outside this range raise ValueError.
_MIN_OPTIM_LEVEL = 0
_MAX_OPTIM_LEVEL = 5
[docs]
def build_engine(
    onnx: Path | str,
    output: Path | str,
    default_device: trt.DeviceType | str = trt.DeviceType.GPU,
    workspace: float = 4.0,
    dla_core: int | None = None,
    calibration_cache: Path | str | None = None,
    data_batcher: AbstractBatcher | None = None,
    layer_precision: list[tuple[int, trt.DataType | None]] | None = None,
    layer_device: list[tuple[int, trt.DeviceType | None]] | None = None,
    shapes: Sequence[tuple[str, tuple[int, ...]]] | None = None,
    input_tensor_formats: list[tuple[str, trt.DataType, trt.TensorFormat]] | None = None,
    output_tensor_formats: list[tuple[str, trt.DataType, trt.TensorFormat]] | None = None,
    hooks: list[Callable[[trt.INetworkDefinition], trt.INetworkDefinition]] | None = None,
    optimization_level: int = 3,
    profiling_verbosity: trt.ProfilingVerbosity | None = None,
    tiling_optimization_level: trt.TilingOptimizationLevel | None = None,
    tiling_l2_cache_limit: int | None = None,
    device: int | None = None,
    *,
    timing_cache: Path | str | bool | None = None,
    gpu_fallback: bool = False,
    direct_io: bool = False,
    prefer_precision_constraints: bool = False,
    reject_empty_algorithms: bool = False,
    ignore_timing_mismatch: bool = False,
    fp16: bool | None = None,
    fp8: bool | None = None,
    int8: bool | None = None,
    cache: bool | None = None,
    verbose: bool | None = None,
) -> None:
    """
    Build a TensorRT engine from an ONNX model.

    The order in which operations occur inside build_engine:

    0. Check the engine cache and validate arguments
    1. Parse the ONNX model
    2. Apply any network hooks
    3. Create optimization profile and apply any manual shapes
    4. Apply builder flags (precision constraints, empty algorithms, direct I/O)
    5. Configure tensor formats if specified
    6. Configure precision (FP16, FP8, INT8)
    7. Set default device and DLA core
    8. Apply individual layer precision and device settings
    9. Set up timing cache
    10. Build the engine
    11. Save timing cache and engine

    Parameters
    ----------
    onnx : Path, str
        The path to the onnx model.
    output : Path, str
        The location to save the TensorRT engine.
    default_device : trt.DeviceType, str, optional
        The device to use for the engine.
        By default, trt.DeviceType.GPU.
        Options are trt.DeviceType.GPU, trt.DeviceType.DLA, or a string
        of "gpu" or "dla".
    timing_cache : Path, str, bool, optional
        Where to store the timing cache data.
        Can be a Path or str to a specific file, "global" or True to use
        the global timing cache stored in the trtutils cache directory,
        or None/False to not use a timing cache.
        Default is None.
    workspace : float
        The size of the workspace in gigabytes.
        Default is 4.0 GiB.
    calibration_cache : Path, str, optional
        The path to the calibration cache.
    data_batcher : AbstractBatcher, optional
        The data batcher to use for calibration.
    dla_core : int, optional
        The DLA core to build the engine for.
        By default, None or build the engine for GPU.
    layer_precision : list[tuple[int, trt.DataType | None]], optional
        The precision to use for specific layers, as (layer_index, precision)
        pairs. Entries whose precision is None are skipped.
        By default, None.
    layer_device : list[tuple[int, trt.DeviceType | None]], optional
        The device to use for specific layers, as (layer_index, device)
        pairs. Entries whose device is None are skipped.
        By default, None.
    shapes : list[tuple[str, tuple[int, ...]]], optional
        A list of (input_name, shape) pairs to specify the shapes of the input layers.
        For example, shapes=[("images", (1, 3, imgsz, imgsz))] will set the input
        "images" to a fixed shape. This shape will be used as the min, optimal,
        and max shape for the binding.
        By default, None.
    input_tensor_formats : list[tuple[str, trt.DataType, trt.TensorFormat]], optional
        A list of (name, dtype, format) to allow deep specification of input layers.
        For example, input_tensor_formats=[("input", trt.DataType.UINT8, trt.TensorFormat.HWC)]
        Specifying any format enables direct IO.
        By default, None
    output_tensor_formats : list[tuple[str, trt.DataType, trt.TensorFormat]], optional
        A list of (name, dtype, format) to allow deep specification of output layers.
        For example, output_tensor_formats=[("output", trt.DataType.HALF, trt.TensorFormat.LINEAR)]
        Specifying any format enables direct IO.
        By default, None
    hooks : list[Callable[[trt.INetworkDefinition], trt.INetworkDefinition]], optional
        An optional list of 'hook' functions to modify the TensorRT network before
        the remainder of the build phase occurs.
        By default, None
    optimization_level : int, optional
        Optimization level to apply to the TensorRT builder config (0-5).
        By default, 3.
    profiling_verbosity : trt.ProfilingVerbosity | None, optional
        Level of detail for profiling information in the built engine.
        Options are: trt.ProfilingVerbosity.NONE, trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
        trt.ProfilingVerbosity.DETAILED
        DETAILED is recommended for best layer names when using profile_engine.
        By default, None (uses TensorRT's default).
    tiling_optimization_level : trt.TilingOptimizationLevel | None, optional
        Tiling optimization level to enable cross-kernel tiled inference.
        By default, None (the builder config is left untouched).
    tiling_l2_cache_limit : int, None, optional
        L2 cache limit (in bytes) for tiling optimization.
        By default, None (TensorRT manages the default value).
    device : int, optional
        The CUDA device index to build the engine on. Default is None,
        which uses the current device.
    gpu_fallback : bool
        Whether or not to allow GPU fallback for unsupported layers
        when building the engine for DLA.
        By default, False
    direct_io : bool
        Use direct IO for the engine.
        By default, False
    prefer_precision_constraints : bool
        Whether or not to prefer precision constraints.
        By default, False
    reject_empty_algorithms : bool
        Whether or not to reject empty algorithms.
        By default, False
    ignore_timing_mismatch : bool
        Whether or not to allow different CUDA device generated timing
        caches to be used in the building of engines.
        By default, False
    fp16 : bool, optional
        If True, quantize the engine to FP16 precision.
    fp8 : bool, optional
        If True, enable FP8 precision for the engine.
        Requires compute capability >= 8.9 (Ada Lovelace / Hopper or newer).
    int8 : bool, optional
        If True, quantize the engine to INT8 precision.
    cache : bool, optional
        Whether or not to cache the engine in the trtutils engine cache.
        If an existing version is found will use that.
        Uses the name of the output file to assess if the engine has been compiled before.
        As such, naming the output 'engine', 'model' or similar will result in
        unintended caching behavior.
        By default None, will not cache the engine.
    verbose : bool, optional
        If True, print verbose output.
        By default, None or False

    Raises
    ------
    RuntimeError
        If the ONNX model cannot be parsed
    RuntimeError
        If the TensorRT engines fails to build
    ValueError
        If timing_cache, default_device or optimization_level is invalid
    ValueError
        If layer is manually assigned to DLA and DLA is not supported
        and gpu_fallback is False

    """
    # load libnvinfer plugins
    CONFIG.load_plugins()

    output_path = Path(output).resolve()

    # a cache hit short-circuits the whole build
    if cache:
        exists, location = caching_tools.query(output_path.stem)
        if exists:
            shutil.copy(location, output_path)
            return

    # validate the cheap arguments up front so that a bad value fails
    # before the ONNX model is parsed and hooks are run
    use_global_timing_cache, timing_cache_path = _resolve_timing_cache(timing_cache)
    default_device = _resolve_default_device(default_device)
    if not (_MIN_OPTIM_LEVEL <= optimization_level <= _MAX_OPTIM_LEVEL):
        err_msg = "Builder optimization level must be between 0 and 5."
        raise ValueError(err_msg)

    # read the onnx model
    network, builder, config, _ = read_onnx(
        onnx,
        workspace,
    )

    # handle all hooks to start, each hook returns the network to keep using
    if hooks is not None:
        for hook in hooks:
            network = hook(network)

    # helper function for checking if layer can run on DLA
    check_dla: Callable[[trt.ILayer], bool] = get_check_dla(config)

    if verbose and FLAGS.BUILD_PROGRESS and ProgressBar is not None:
        LOG.debug("Applying ProgressBar to config")
        config.progress_monitor = ProgressBar()

    # create the profile; manual shapes pin min/opt/max to the same value
    profile = builder.create_optimization_profile()
    if shapes:
        for input_name, shape in shapes:
            profile.set_shape(input_name, shape, shape, shape)
    config.add_optimization_profile(profile)

    config.builder_optimization_level = int(optimization_level)

    # handle profiling verbosity
    if profiling_verbosity is not None:
        config.profiling_verbosity = profiling_verbosity

    # handle tiling optimization
    if tiling_optimization_level is not None:
        config.tiling_optimization_level = tiling_optimization_level
    if tiling_l2_cache_limit is not None:
        config.l2_limit_for_tiling = tiling_l2_cache_limit

    # handle some flags
    if prefer_precision_constraints:
        config.set_flag(trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS)
    if reject_empty_algorithms:
        config.set_flag(trt.BuilderFlag.REJECT_EMPTY_ALGORITHMS)

    # handle custom datatype/format for input/output tensors
    if (input_tensor_formats is not None or output_tensor_formats is not None) and not direct_io:
        LOG.warning("Direct IO not enabled, but some tensor formats specified. Enabling direct IO.")
        direct_io = True
    if direct_io:
        config.set_flag(trt.BuilderFlag.DIRECT_IO)
    if input_tensor_formats is not None:
        _apply_tensor_formats(
            "Input",
            [network.get_input(idx) for idx in range(network.num_inputs)],
            input_tensor_formats,
        )
    if output_tensor_formats is not None:
        _apply_tensor_formats(
            "Output",
            [network.get_output(idx) for idx in range(network.num_outputs)],
            output_tensor_formats,
        )

    # setup the precision sets
    if fp16 or fp8 or int8:
        # want to enable fp16 for int8, fp8, and fp16 since fp16 may be faster
        if not builder.platform_has_fast_fp16:
            LOG.warning("Platform does not have native fast FP16.")
        config.set_flag(trt.BuilderFlag.FP16)
    if fp8:
        config.set_flag(trt.BuilderFlag.FP8)
    if int8:
        if not builder.platform_has_fast_int8:
            LOG.warning("Platform does not have native fast INT8.")
        config.set_flag(trt.BuilderFlag.INT8)
        if calibration_cache is None and data_batcher is None:
            err_msg = "Neither calibration cache or data batcher passed during model building, INT8 build will not be accurate."
            LOG.warning(err_msg)
        config.int8_calibrator = EngineCalibrator(calibration_cache=calibration_cache)
        if data_batcher is not None:
            config.int8_calibrator.set_batcher(data_batcher)

    # assign the default device
    config.default_device_type = default_device

    # handle DLA assignment
    if dla_core is not None:
        config.DLA_core = dla_core
    if gpu_fallback:
        config.set_flag(trt.BuilderFlag.GPU_FALLBACK)

    # handle individual layer precision
    # no length validation: each entry carries its own layer index
    if layer_precision is not None:
        for layer_idx, precision in layer_precision:
            if precision is None:
                continue
            layer = network.get_layer(layer_idx)
            layer.precision = precision

    # handle individual layer device
    # no length validation: each entry carries its own layer index
    if layer_device is not None:
        for layer_idx, layer_dev in layer_device:
            if layer_dev is None:
                continue
            layer = network.get_layer(layer_idx)
            # assess if can run on DLA
            if layer_dev == trt.DeviceType.DLA and not check_dla(layer):
                err_msg = f"Layer {layer.name} (type: {layer.type}) cannot run on DLA"
                if not gpu_fallback:
                    raise ValueError(err_msg)
                # the layer is left on the default device when falling back
                err_msg += ", using GPU fallback"
                LOG.warning(err_msg)
            else:
                config.set_device_type(layer, layer_dev)

    # load/setup the timing cache, seeded from disk when a file already exists
    t_cache: trt.ITimingCache | None = None
    if use_global_timing_cache or timing_cache_path is not None:
        if use_global_timing_cache:
            # use global timing cache from cache directory
            seed_exists, seed_path = query_timing_cache()
        else:
            # use specified timing cache path
            seed_path = timing_cache_path
            seed_exists = seed_path.exists()
        buffer = b""
        if seed_exists:
            with seed_path.open("rb") as timing_cache_file:
                buffer = timing_cache_file.read()
        t_cache = config.create_timing_cache(buffer)
        config.set_timing_cache(t_cache, ignore_mismatch=ignore_timing_mismatch)

    # build the engine
    with Device(device):
        if FLAGS.BUILD_SERIALIZED:
            engine_bytes = builder.build_serialized_network(network, config)
        else:
            # NOTE(review): per the TensorRT Python API the legacy
            # Builder.build_engine returns an engine object rather than a
            # serialized buffer, so it is serialized here before being
            # written to disk - confirm against the oldest supported TensorRT
            built = builder.build_engine(network, config)
            if built is not None and hasattr(built, "serialize"):
                engine_bytes = built.serialize()
            else:
                engine_bytes = built

    # save the timing cache (done before the failure check below, so timing
    # data is written even when the build did not produce an engine)
    if use_global_timing_cache:
        # save to global timing cache in cache directory
        post_t_cache = config.get_timing_cache()
        save_timing_cache_to_global(post_t_cache, overwrite=True)
    elif timing_cache_path is not None:
        # save to specified timing cache path
        post_t_cache = config.get_timing_cache()
        with timing_cache_path.open("wb") as f:
            f.write(memoryview(post_t_cache.serialize()))

    if engine_bytes is None:
        err_msg = "Failed to build engine."
        raise RuntimeError(err_msg)

    with output_path.open("wb") as f:
        f.write(engine_bytes)

    if cache:
        caching_tools.store(output_path, overwrite=False, delete_source=False)


def _resolve_timing_cache(
    timing_cache: Path | str | bool | None,
) -> tuple[bool, Path | None]:
    """Translate the timing_cache argument into (use_global, explicit_path)."""
    if timing_cache is True or timing_cache == "global":
        return True, None
    # False is accepted by the signature (bool) and means "disabled", like None
    if timing_cache is None or timing_cache is False:
        return False, None
    if isinstance(timing_cache, (Path, str)):
        return False, Path(timing_cache).resolve()
    err_msg = (
        f"Invalid timing_cache value: {timing_cache}. "
        "Must be None, Path, str, True, or 'global'."
    )
    raise ValueError(err_msg)


def _resolve_default_device(default_device: trt.DeviceType | str) -> trt.DeviceType:
    """Validate default_device and convert a "gpu"/"dla" string to trt.DeviceType."""
    valid_gpu = ["gpu", "GPU"]
    valid_dla = ["dla", "DLA"]
    if isinstance(default_device, str):
        if default_device not in valid_gpu + valid_dla:
            err_msg = (
                f"Invalid default device: {default_device}. Must be one of: {valid_gpu + valid_dla}"
            )
            raise ValueError(err_msg)
        return trt.DeviceType.GPU if default_device in valid_gpu else trt.DeviceType.DLA
    if default_device not in [trt.DeviceType.GPU, trt.DeviceType.DLA]:
        err_msg = (
            f"Invalid default device: {default_device}. Must be one of: {valid_gpu + valid_dla}"
        )
        raise ValueError(err_msg)
    return default_device


def _apply_tensor_formats(
    kind: str,
    tensors: list[trt.ITensor],
    formats: list[tuple[str, trt.DataType, trt.TensorFormat]],
) -> None:
    """Set dtype/allowed format on the first tensor matching each requested name."""
    for tensor_name, tensor_dtype, tensor_format in formats:
        tensor = next((t for t in tensors if t.name == tensor_name), None)
        if tensor is None:
            # kind is "Input" or "Output" and only feeds this warning
            LOG.warning(f"{kind} tensor '{tensor_name}' not found in network")
            continue
        tensor.dtype = tensor_dtype
        # allowed_formats is a bitmask with one bit per TensorFormat value
        tensor.allowed_formats = 1 << int(tensor_format)