Source code for trtutils.builder._calibrator

# Copyright (c) 2024-2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations

import tempfile
from pathlib import Path
from typing import TYPE_CHECKING

from trtutils._log import LOG
from trtutils.compat._libs import trt
from trtutils.core._memory import cuda_malloc, memcpy_host_to_device

if TYPE_CHECKING:
    from typing_extensions import Self

    from ._batcher import AbstractBatcher


[docs] class EngineCalibrator(trt.IInt8EntropyCalibrator2): """Implements the trt.IInt8EntropyCalibrator2.""" def __init__( self: Self, calibration_cache: Path | str | None = None, ) -> None: """ Create an EngineCalibrator. Parameters ---------- calibration_cache : Path, str, optional The path to the calibration cache. """ super().__init__() if calibration_cache is not None: self._cache_path: Path = Path(calibration_cache).resolve() else: with tempfile.NamedTemporaryFile(suffix=".cache", delete=False) as tmp: self._cache_path = Path(tmp.name) LOG.warning( "No calibration cache path provided, using a temporary file." " Calibration data will not persist across runs." ) self._batcher: AbstractBatcher | None = None
[docs] def set_batcher(self: Self, batcher: AbstractBatcher) -> None: """Set the batcher.""" self._batcher = batcher
[docs] def get_batch_size(self: Self) -> int: """ Get the batch size. Overrides from trt.IInt8EntropyCalibrator2. Returns ------- int The batch size """ if self._batcher: return self._batcher.batch_size return 1
[docs] def get_batch(self: Self, names: list[str]) -> list[int] | None: # noqa: ARG002 """ Get the next batch of data. Overrides from trt.IInt8EntropyCalibrator2. Parameters ---------- names : list[str] The list of inputs, if useful to define the batch. Returns ------- list[int] GPU-Memory pointers of the next batch """ # if we dont have an image batcher, dont handle calibration if self._batcher is None: return None # if we do load the image batch = self._batcher.get_next_batch() if batch is None: return None # allocate GPU memory for the batch # return the GPU pointer ptr = cuda_malloc(batch.nbytes) memcpy_host_to_device(ptr, batch) return [ptr]
[docs] def read_calibration_cache(self: Self) -> bytes | None: """ Read the calibration cache file if it exists. Overrides from trt.IInt8EntropyCalibrator2. Returns ------- bytes | None The calibration cache contents if it exists """ if self._cache_path is None: return None if not self._cache_path.exists(): return None with self._cache_path.open("rb") as f: LOG.debug(f"Reading calibration cache file: {self._cache_path}") data: bytes = f.read() return data
[docs] def write_calibration_cache(self: Self, cache: bytes) -> None: """ Write the calibration date to the calibration cache file. Overrides from trt.IInt8EntropyCalibrator2. Parameters ---------- cache : bytes The calibration data generated. """ if self._cache_path is None: return with self._cache_path.open("wb") as f: LOG.debug(f"Writing calibration cache file: {self._cache_path}") f.write(cache)