Source code for trtutils.builder._calibrator
# Copyright (c) 2024-2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING
from trtutils._log import LOG
from trtutils.compat._libs import trt
from trtutils.core._memory import cuda_malloc, memcpy_host_to_device
if TYPE_CHECKING:
from typing_extensions import Self
from ._batcher import AbstractBatcher
[docs]
class EngineCalibrator(trt.IInt8EntropyCalibrator2):
    """
    Implements the trt.IInt8EntropyCalibrator2.

    Calibration batches are pulled from an ``AbstractBatcher`` (see
    ``set_batcher``) and copied into a single device buffer that is reused
    across ``get_batch`` calls. The calibration cache is read from and
    written to a file on disk.
    """

    def __init__(
        self: Self,
        calibration_cache: Path | str | None = None,
    ) -> None:
        """
        Create an EngineCalibrator.

        Parameters
        ----------
        calibration_cache : Path, str, optional
            The path to the calibration cache. If not provided, a temporary
            file is created (and not deleted) so the cache can still be
            written and re-read by path during this process.

        """
        super().__init__()
        if calibration_cache is not None:
            self._cache_path: Path = Path(calibration_cache).resolve()
        else:
            # delete=False keeps the file after the handle closes, so it can
            # be reopened by path in read/write_calibration_cache.
            with tempfile.NamedTemporaryFile(suffix=".cache", delete=False) as tmp:
                self._cache_path = Path(tmp.name)
            LOG.warning(
                "No calibration cache path provided, using a temporary file."
                " Calibration data will not persist across runs."
            )
        self._batcher: AbstractBatcher | None = None
        # Device buffer reused by get_batch; allocated lazily on first use so
        # that memory is not allocated anew for every calibration batch.
        self._device_ptr: int | None = None
        self._device_nbytes: int = 0

    def set_batcher(self: Self, batcher: AbstractBatcher) -> None:
        """
        Set the batcher that supplies calibration data.

        Parameters
        ----------
        batcher : AbstractBatcher
            The source of calibration batches used by ``get_batch``.

        """
        self._batcher = batcher

    def get_batch_size(self: Self) -> int:
        """
        Get the batch size.

        Overrides from trt.IInt8EntropyCalibrator2.

        Returns
        -------
        int
            The batcher's batch size, or 1 if no batcher has been set.

        """
        if self._batcher is not None:
            return self._batcher.batch_size
        return 1

    def get_batch(self: Self, names: list[str]) -> list[int] | None:  # noqa: ARG002
        """
        Get the next batch of data.

        Overrides from trt.IInt8EntropyCalibrator2.

        The host batch is copied into a device buffer owned by this
        calibrator. The same buffer is reused on subsequent calls, so the
        returned pointer is only meaningful until the next call.

        Parameters
        ----------
        names : list[str]
            The list of inputs, if useful to define the batch (unused; only a
            single input is supported).

        Returns
        -------
        list[int] | None
            A single-element list with the GPU-memory pointer of the next
            batch, or None when there is no batcher or no more data.

        """
        # Without a batcher there is no calibration data to provide.
        if self._batcher is None:
            return None
        batch = self._batcher.get_next_batch()
        if batch is None:
            # Batcher exhausted: signals the end of calibration.
            return None
        nbytes = batch.nbytes
        if self._device_ptr is None or nbytes > self._device_nbytes:
            # NOTE(review): if a later batch is larger than the first, the
            # previous buffer is not freed here. Batches presumably share one
            # size, so this should allocate exactly once.
            self._device_ptr = cuda_malloc(nbytes)
            self._device_nbytes = nbytes
        memcpy_host_to_device(self._device_ptr, batch)
        return [self._device_ptr]

    def read_calibration_cache(self: Self) -> bytes | None:
        """
        Read the calibration cache file if it exists.

        Overrides from trt.IInt8EntropyCalibrator2.

        Returns
        -------
        bytes | None
            The calibration cache contents, or None if the file does not
            exist or is empty (e.g. a freshly created temporary file), so
            that calibration is performed instead.

        """
        if not self._cache_path.exists():
            return None
        with self._cache_path.open("rb") as f:
            LOG.debug(f"Reading calibration cache file: {self._cache_path}")
            data: bytes = f.read()
        # An empty file is never a valid cache; report "no cache" instead.
        if not data:
            return None
        return data

    def write_calibration_cache(self: Self, cache: bytes) -> None:
        """
        Write the calibration data to the calibration cache file.

        Overrides from trt.IInt8EntropyCalibrator2.

        Parameters
        ----------
        cache : bytes
            The calibration data generated.

        """
        # Ensure the target directory exists; otherwise open() would fail
        # inside the TensorRT callback.
        self._cache_path.parent.mkdir(parents=True, exist_ok=True)
        with self._cache_path.open("wb") as f:
            LOG.debug(f"Writing calibration cache file: {self._cache_path}")
            f.write(cache)