Source code for trtutils.builder._batcher

# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
from __future__ import annotations

import atexit
import concurrent.futures
import contextlib
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Empty, Full, Queue
from threading import Event, Thread
from typing import TYPE_CHECKING

import cv2
import numpy as np
from cv2ext.image import letterbox, rescale

from trtutils._log import LOG

if TYPE_CHECKING:
    from typing_extensions import Self


class AbstractBatcher(ABC):
    """Abstract base class for data batching classes."""

    @abstractmethod
    def __init__(self: Self) -> None:
        """Initialize the batcher."""

    @property
    @abstractmethod
    def num_batches(self: Self) -> int:
        """Get the number of batches."""

    @property
    @abstractmethod
    def batch_size(self: Self) -> int:
        """Get the batch size."""

    @abstractmethod
    def get_next_batch(self: Self) -> np.ndarray | None:
        """Get the batch of data."""


[docs] class ImageBatcher(AbstractBatcher): """Creates image batches for calibrating TensorRT engines.""" def __init__( self: Self, image_dir: Path | str, shape: tuple[int, int, int], dtype: np.dtype, batch_size: int = 8, order: str = "NCHW", max_images: int | None = None, resize_method: str = "letterbox", input_scale: tuple[float, float] = (0.0, 1.0), *, verbose: bool | None = None, ) -> None: """ Create batches of images for TensorRT calibration. Parameters ---------- image_dir : Path, str The directory containing the images to calibrate the model with. shape : tuple[int, int, int] The expected input shape of the network in format HWC (height, width, channels) dtype : np.dtype The expected datatype input of the network batch_size : int, optional The batch size to group images in. Default is 8 order : str, optional The ordering of data elements expected by the network. Options are: ['NCHW', 'NHWC'], default is 'NCHW' max_images : int, optional Optionally, specify the maximum number of images to calibrate with resize_method : str, optional The method by which to resize the images Options are: ['letterbox', 'linear'], default is 'letterbox' input_scale : tuple[float, float], optional The range with which the image should have values in Examples are: (0.0, 255.0), (0.0, 1.0), (-1.0, 1.0) The default is (0.0, 1.0) verbose : bool, optional Whether to print verbose output, by default None Raises ------ FileNotFoundError If the input_dir cannot be found NotADirectoryError If the input_dir is a file ValueError If the max_images is set less than one ValueError If no images could be found in the input_dir ValueError If no valid batches could be formed """ self._verbose = verbose # verify resize method and input scale valid_resize_methods = ["letterbox", "linear"] if resize_method not in valid_resize_methods: err_msg = ( f"Invalid resize method found, options are: {valid_resize_methods}" ) raise ValueError(err_msg) self._resize_method = resize_method self._input_scale = input_scale # handle the shape and dtype self._dtype = dtype self._batch = batch_size self._height = shape[0] self._width = shape[1] self._channel = shape[2] # assign order valid_orders: list[str] = ["NCHW", "NHWC"] if order not in valid_orders: err_msg = f"Invalid order found, {order}, options are: {valid_orders}" raise ValueError(err_msg) self._order = order self._data_shape: tuple[int, int, int, int] = ( self._batch, self._channel, self._height, self._width, ) if order == "NHWC": self._data_shape = (self._batch, self._height, self._width, self._channel) # verify image directory image_dir_path: Path = Path(image_dir) if not image_dir_path.exists(): err_msg = f"Could not find image directory: {image_dir}" raise FileNotFoundError(err_msg) if image_dir_path.is_file(): err_msg = f"Image directory: {image_dir} is a file." raise NotADirectoryError(err_msg) # get the paths of all images image_exts: list[str] = [".jpg", ".jpeg", ".png"] self._images: list[Path] = [] for extension in image_exts: for file in image_dir_path.glob(f"*{extension}"): self._images.append(file) # ensure we found some images if len(self._images) == 0: err_msg = f"Could not find any images in directory: {image_dir}" raise ValueError(err_msg) # make the list deterministic self._images.sort() # trim to max_images if max_images is not None: if not max_images > 1: err_msg = f"max_images must be greated than 1, found: {max_images}" raise ValueError(err_msg) self._images = self._images[0:max_images] # generate batches list self._batches: list[list[Path]] = [] for i in range(0, len(self._images), self._batch): with contextlib.suppress(IndexError): batch: list[Path] = [self._images[i + j] for j in range(self._batch)] self._batches.append(batch) if len(self._batches) == 0: err_msg = "Could not form any valid batches." raise ValueError(err_msg) LOG.debug(f"ImageBatcher found images: {len(self._images)}") LOG.debug(f"ImageBatcher formed batches: {len(self._batches)}") # tracking indices for iteration self._current_batch: int = 0 # threading setup to speedup loading cpu_cores = os.cpu_count() if cpu_cores is None: cpu_cores = 1 self._pool = concurrent.futures.ThreadPoolExecutor( max_workers=min(self._batch, cpu_cores - 1), ) self._event: Event = Event() self._queue: Queue[np.ndarray] = Queue(maxsize=3) self._thread: Thread = Thread(target=self._run, daemon=True) atexit.register(self._close) self._thread.start() @property def num_batches(self: Self) -> int: """Get the number of batches.""" return len(self._batches) @property def batch_size(self: Self) -> int: """Get the batch size.""" return self._batch def _close(self: Self) -> None: self._event.set() self._thread.join() def _get_image(self: Self, image_path: Path) -> np.ndarray: # read image img = cv2.imread(str(image_path.resolve())) # resize and rescale the image if self._resize_method == "letterbox": resized_img, _, _ = letterbox(img, (self._width, self._height)) else: resized_img = cv2.resize( img, (self._width, self._height), interpolation=cv2.INTER_LINEAR, ) rescaled_img = rescale(resized_img, self._input_scale) # run the transpose operations and make contiguous new_img = rescaled_img[np.newaxis, :] if self._order == "NCHW": new_img = np.transpose(new_img, (0, 3, 1, 2)) else: new_img = np.transpose(new_img, (0, 1, 2, 3)) new_img = new_img.astype(self._dtype) if not new_img.flags["C_CONTIGUOUS"]: new_img = np.ascontiguousarray(new_img) return new_img def _run(self: Self) -> None: # for each batch get the images for idx, image_paths in enumerate(self._batches): if self._event.is_set(): return LOG.debug(f"ImageBatcher getting batch: {idx}") # get the batch results = list(self._pool.map(self._get_image, image_paths)) data = np.zeros(self._data_shape, dtype=self._dtype) for i, img in enumerate(results): data[i] = img if not data.flags["C_CONTIGUOUS"]: data = np.ascontiguousarray(data) # keep trying to put images into the queue while not self._event.is_set(): try: self._queue.put(data, timeout=0.1) if self._verbose: LOG.debug( f"ImageBatcher put batch: {idx} / {len(self._batches)}" ) break except Full: continue
[docs] def get_next_batch(self: Self) -> np.ndarray | None: """ Get a batch of images which have been preprocessed. Returns ------- np.ndarray | None The batch of images if one exists """ if self._current_batch == len(self._batches): return None while not self._event.is_set(): with contextlib.suppress(Empty): batch = self._queue.get(timeout=0.1) self._current_batch += 1 if self._verbose: LOG.debug( f"ImageBatcher get batch: {self._current_batch} / {len(self._batches)}" ) return batch return None