# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)## MIT License# mypy: disable-error-code="import-untyped"from__future__importannotationsimportcontextlibfrompathlibimportPathwithcontextlib.suppress(ImportError):importtensorrtastrtfromtrtutils._flagsimportFLAGSfromtrtutils.core._engineimportcreate_enginefromtrtutils.core._streamimportdestroy_stream
def get_engine_names(
    engine: Path | str | trt.ICudaEngine,
) -> tuple[list[str], list[str]]:
    """
    Get the input/output names of a TensorRT engine in order.

    When ``engine`` is a path, the engine is loaded with ``create_engine``
    for the duration of the call and everything that call returned is
    released again before this function returns, including when the
    enumeration raises. An engine object passed in by the caller is only
    read and is never released here.

    Parameters
    ----------
    engine : Path | str | trt.ICudaEngine
        Path to the TensorRT engine file or an already loaded engine

    Returns
    -------
    tuple[list[str], list[str]]
        The input and output tensors in order of enumeration.

    """
    loaded = isinstance(engine, (Path, str))
    if loaded:
        # Keep the caller's argument untouched; hold the loaded objects in
        # their own locals so they can be dropped explicitly afterwards.
        trt_engine, context, logger, stream = create_engine(engine)
    else:
        trt_engine = engine

    input_names: list[str] = []
    output_names: list[str] = []

    try:
        # FLAGS.TRT_10 selects between the name-based tensor API and the
        # index-based binding API; read it once for the whole enumeration.
        use_tensor_api = FLAGS.TRT_10
        num_tensors = (
            trt_engine.num_io_tensors
            if use_tensor_api
            else trt_engine.num_bindings
        )

        for i in range(num_tensors):
            # get the tensor name in-order
            if use_tensor_api:
                tensor_name = trt_engine.get_tensor_name(i)
                is_input = (
                    trt_engine.get_tensor_mode(tensor_name)
                    == trt.TensorIOMode.INPUT
                )
            else:
                tensor_name = trt_engine.get_binding_name(i)
                is_input = trt_engine.binding_is_input(i)

            # store
            if is_input:
                input_names.append(tensor_name)
            else:
                output_names.append(tensor_name)
    finally:
        if loaded:
            # Drop the references to the loaded objects before destroying
            # the stream (same order as the success path always used), and
            # do so even if the enumeration above failed.
            del trt_engine
            del context
            del logger
            destroy_stream(stream)

    return input_names, output_names