Source code for trtutils.builder._onnx
# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations
from pathlib import Path
from trtutils._config import CONFIG
from trtutils._log import LOG
from trtutils.compat._libs import trt
[docs]
def read_onnx(
onnx: Path | str,
workspace: float = 4.0,
) -> tuple[
trt.INetworkDefinition,
trt.IBuilder,
trt.IBuilderConfig,
trt.IOnnxParser,
]:
"""
Open an ONNX model and generate TensorRT network, builder, config, and parser.
Parameters
----------
onnx : Path, str
The path to the onnx model.
workspace : float
The size of the workspace in gigabytes.
Default is 4.0 GiB.
Returns
-------
tuple[trt.INetworkDefinition, trt.IBuilder, trt.IBuilderConfig, trt.IOnnxParser]
The network, builder, config, and parser.
Raises
------
FileNotFoundError
If the onnx model does not exist
IsADirectoryError
If the onnx model path is a directory
ValueError
If the onnx model path does not have .onnx extension
RuntimeError
If the ONNX model cannot be parsed
"""
# load libnvinfer plugins
CONFIG.load_plugins()
onnx_path = Path(onnx).resolve()
if not onnx_path.exists():
err_msg = f"Could not find ONNX model at: {onnx_path}"
raise FileNotFoundError(err_msg)
if onnx_path.is_dir():
err_msg = f"Path given is a directory: {onnx_path}"
raise IsADirectoryError(err_msg)
if onnx_path.suffix != ".onnx":
err_msg = "File does not have .onnx extension"
raise ValueError(err_msg)
builder = trt.Builder(LOG)
config = builder.create_builder_config()
# setup the workspace size
workspace_bytes = int(workspace * (1 << 30))
if hasattr(config, "max_workspace_size"):
config.max_workspace_size = workspace_bytes
else:
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
# make network
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH),
)
# setup parser
parser = trt.OnnxParser(network, LOG)
with onnx_path.open("rb") as f:
if not parser.parse(f.read()):
for error in range(parser.num_errors):
LOG.error(parser.get_error(error))
err_msg = "Cannot parse ONNX file"
raise RuntimeError(err_msg)
return network, builder, config, parser