# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)## MIT License# mypy: disable-error-code="import-untyped"from__future__importannotationsimportcontextlibfrompathlibimportPathfromtypingimportTYPE_CHECKINGwithcontextlib.suppress(Exception):importtensorrtastrtfromtrtutils._configimportCONFIGfromtrtutils._logimportLOGfrom._streamimportcreate_streamifTYPE_CHECKING:try:importcuda.bindings.cudartascudartexcept(ImportError,ModuleNotFoundError):fromcudaimportcudart
def create_engine(
    engine_path: Path | str,
    stream: cudart.cudaStream_t | None = None,
    dla_core: int | None = None,
    *,
    no_warn: bool | None = None,
) -> tuple[trt.ICudaEngine, trt.IExecutionContext, trt.ILogger, cudart.cudaStream_t]:
    """
    Load a serialized engine from disk.

    Parameters
    ----------
    engine_path : Path | str
        The path to the serialized engine file.
    stream : cudart.cudaStream_t, optional
        When an already made stream is passed, no new stream is created.
        Useful if you want multiple engines to share the same stream.
        Although there is no explicit link between engine and stream,
        the stream returned by this function should be used for execution.
    dla_core : int, optional
        The DLA core to assign DLA layers of the engine to.
        Default is None. If None, the runtime's DLA core setting is not
        modified and TensorRT's own default applies.
    no_warn : bool | None, optional
        If True, the engine is deserialized inside ``LOG.suppress()`` to
        silence TensorRT warnings. Default is None (no suppression).

    Returns
    -------
    tuple[trt.ICudaEngine, trt.IExecutionContext, trt.ILogger, cudart.cudaStream_t]
        The deserialized engine, its execution context, the logger used
        (always the module-level ``LOG``), and the stream: the one passed
        in via ``stream`` if given, otherwise a newly created one.

    Raises
    ------
    FileNotFoundError
        If the engine file is not found.
    RuntimeError
        If the TRT runtime could not be created.
        If the engine could not be deserialized.
        If the execution context could not be created.

    """
    # load libnvinfer plugins
    CONFIG.load_plugins()

    engine_path = Path(engine_path)
    if not engine_path.exists():
        err_msg = f"Engine file not found: {engine_path}"
        raise FileNotFoundError(err_msg)

    # load the engine from file
    # explicitly a thread-safe operation
    # https://docs.nvidia.com/deeplearning/tensorrt/latest/architecture/how-trt-works.html
    runtime = trt.Runtime(LOG)
    # validate the runtime BEFORE touching any of its attributes, so that a
    # failed creation surfaces as the documented RuntimeError rather than an
    # AttributeError from the DLA_core assignment below
    if runtime is None:
        err_msg = "Failed to create TRT runtime"
        raise RuntimeError(err_msg)
    if dla_core is not None:
        runtime.DLA_core = dla_core

    serialized_engine = engine_path.read_bytes()

    # only silence the logger when explicitly requested; None/False both
    # deserialize with normal logging
    log_scope = LOG.suppress() if no_warn else contextlib.nullcontext()
    with log_scope:
        engine = runtime.deserialize_cuda_engine(serialized_engine)

    # final check on engine
    if engine is None:
        err_msg = f"Failed to deserialize engine from {engine_path}"
        raise RuntimeError(err_msg)

    # create the execution context
    # explicitly a thread-safe operation
    # https://docs.nvidia.com/deeplearning/tensorrt/latest/architecture/how-trt-works.html
    context = engine.create_execution_context()
    if context is None:
        err_msg = "Failed to create execution context"
        raise RuntimeError(err_msg)

    # create a cudart stream only when the caller did not supply one
    if stream is None:
        stream = create_stream()

    return engine, context, LOG, stream