# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)## MIT License# mypy: disable-error-code="import-untyped"from__future__importannotationsimportcontextlibfrompathlibimportPathfromtypingimportTYPE_CHECKINGwithcontextlib.suppress(ImportError):importtensorrtastrtfromtrtutils._logimportLOGfromtrtutils.coreimportcuda_malloc,memcpy_host_to_deviceifTYPE_CHECKING:fromtyping_extensionsimportSelffrom._batcherimportAbstractBatcher
[docs]classEngineCalibrator(trt.IInt8EntropyCalibrator2):# type: ignore[misc]"""Implements the trt.IInt8EntropyCalibrator2."""def__init__(self:Self,calibration_cache:Path|str|None=None,)->None:""" Create an EngineCalibrator. Parameters ---------- calibration_cache : Path, str, optional The path to the calibration cache. """super().__init__()self._cache_path:Path=(Path(calibration_cache).resolve()ifcalibration_cacheisnotNoneelsePath("calibration.cache").resolve())self._batcher:AbstractBatcher|None=None
[docs]defset_batcher(self:Self,batcher:AbstractBatcher)->None:"""Set the batcher."""self._batcher=batcher
[docs]defget_batch_size(self:Self)->int:""" Get the batch size. Overrides from trt.IInt8EntropyCalibrator2. Returns ------- int The batch size """ifself._batcher:returnself._batcher.batch_sizereturn1
[docs]defget_batch(self:Self,names:list[str])->list[int]|None:# noqa: ARG002""" Get the next batch of data. Overrides from trt.IInt8EntropyCalibrator2. Parameters ---------- names : list[str] The list of inputs, if useful to define the batch. Returns ------- list[int] GPU-Memory pointers of the next batch """# if we dont have an image batcher, dont handle calibrationifself._batcherisNone:returnNone# if we do load the imagebatch=self._batcher.get_next_batch()ifbatchisNone:returnNone# allocate GPU memory for the batch# return the GPU pointerptr=cuda_malloc(batch.nbytes)memcpy_host_to_device(ptr,batch)return[ptr]
[docs]defread_calibration_cache(self:Self)->bytes|None:""" Read the calibration cache file if it exists. Overrides from trt.IInt8EntropyCalibrator2. Returns ------- bytes | None The calibration cache contents if it exists """ifself._cache_pathisNone:returnNoneifnotself._cache_path.exists():returnNonewithself._cache_path.open("rb")asf:LOG.debug(f"Reading calibration cache file: {self._cache_path}")data:bytes=f.read()returndata
[docs]defwrite_calibration_cache(self:Self,cache:bytes)->None:""" Write the calibration date to the calibration cache file. Overrides from trt.IInt8EntropyCalibrator2. Parameters ---------- cache : bytes The calibration data generated. """ifself._cache_pathisNone:returnwithself._cache_path.open("wb")asf:LOG.debug(f"Writing calibration cache file: {self._cache_path}")f.write(cache)