Source code for trtutils.inspect._names

# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations

from pathlib import Path

from trtutils._engine import TRTEngine


def get_engine_names(
    engine: TRTEngine | Path | str,
) -> tuple[list[str], list[str]]:
    """
    Get the input/output names of a TensorRT engine in order.

    Parameters
    ----------
    engine : TRTEngine | Path | str
        Path to the TensorRT engine file, or an already loaded TRTEngine.
        When a path is given, a temporary TRTEngine is constructed with
        ``warmup=False`` and the local reference to it is dropped before
        returning.

    Returns
    -------
    tuple[list[str], list[str]]
        The input and output tensor names in order of enumeration.
        Both lists are fresh copies, so mutating them does not alter
        the engine's own name lists.

    """
    loaded = False
    if isinstance(engine, (Path, str)):
        # Only the tensor names are read below, so skip the warmup pass.
        engine = TRTEngine(engine, warmup=False)
        loaded = True

    # Copy the names so callers cannot mutate state owned by the engine
    # (relevant when the caller passed in an engine they keep using).
    input_names = list(engine.input_names)
    output_names = list(engine.output_names)

    if loaded:
        # Drop the reference to the engine this function created so it can
        # be reclaimed as soon as no other references remain.
        del engine

    return input_names, output_names