# Copyright (c) 2024-2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
"""
Tools for managing the trtutils TensorRT engine cache.
Useful for quickly recalling pre-compiled TRT engines,
without having to implement your own caching mechanisms.
Used in the TRTPreprocessor to manage compiled engines
for different image sizes.
Functions
---------
:func:`get_cache_dir`
Gets the cache directory inside of the trtutils install.
:func:`clear`
Clears the cache directory.
:func:`query`
Queries the cache to see if an engine with that name already exists.
:func:`store`
Stores a compiled TensorRT engine in the cache.
:func:`remove`
Removes an engine file from the cache.
:func:`query_file`
Queries the cache for a file with a specific extension.
:func:`store_file`
Stores a file in the cache with a specific name.
:func:`remove_file`
Removes a file from the cache.
:func:`query_timing_cache`
Queries the cache for the global timing cache.
:func:`store_timing_cache`
Stores the global timing cache in the cache directory.
:func:`save_timing_cache_to_global`
Saves a TensorRT timing cache object directly to the global timing cache.
"""
# POTENTIAL CHANGE: Update to use platformdirs behind the scenes
# https://github.com/tox-dev/platformdirs/tree/main?tab=readme-ov-file
from __future__ import annotations
import shutil
import tempfile
from pathlib import Path
from typing import Protocol
from trtutils._log import LOG
class _TimingCache(Protocol):
def serialize(self) -> bytes: ...
# Known valid cache file extensions
_VALID_EXTENSIONS = {"engine", "onnx", "cache"}
def _get_cache_file_path(filename: str, extension: str) -> Path:
"""Get the full path for a cache file, handling extension logic."""
has_valid_extension = False
if "." in filename:
ext = filename.rsplit(".", 1)[-1]
has_valid_extension = ext in _VALID_EXTENSIONS
if has_valid_extension:
return get_cache_dir() / filename
return get_cache_dir() / f"{filename}.{extension}"
def _delete_folder(directory: Path) -> None:
for item in directory.iterdir():
if item.is_dir():
_delete_folder(item)
else:
item.unlink()
directory.rmdir()
[docs]
def get_cache_dir() -> Path:
"""
Get the location of the trtutils engine cache directory.
Returns
-------
Path
The trtutils engine cache directory Path
"""
file_path = Path(__file__)
return file_path.parent / "_engine_cache"
[docs]
def clear(*, no_warn: bool | None = None) -> None:
"""
Use to clear the cache folder for the trtutils engines.
Parameters
----------
no_warn : bool, optional
Whether or not to issue a warning that the cache directory
is being cleared.
"""
if not no_warn:
LOG.warning("Engine cache is being cleared")
cache_dir = get_cache_dir()
_delete_folder(cache_dir)
cache_dir.mkdir()
[docs]
def query_file(filename: str, extension: str = "engine") -> tuple[bool, Path]:
"""
Check if a file with the given name and extension is present in the cache.
Parameters
----------
filename : str
The filename to check for. Can be with or without extension.
If extension is provided in filename, it will be used.
extension : str, optional
The file extension to use (without the dot).
By default, "engine".
Returns
-------
tuple[bool, Path]
Whether or not the file exists and its Path (whether or not it exists)
"""
file_path = _get_cache_file_path(filename, extension)
return file_path.exists(), file_path
[docs]
def query(filename: str) -> tuple[bool, Path]:
"""
Check if the engine filename is present in the cache.
Parameters
----------
filename : str
The filename to check for without a suffix.
Returns
-------
tuple[bool, Path]
Whether or not the file exists and its Path (whether or not it exists)
"""
return query_file(filename, extension="engine")
[docs]
def store_file(
filepath: Path,
cache_filename: str | None = None,
*,
overwrite: bool = False,
delete_source: bool = False,
) -> Path:
"""
Store a file in the trtutils cache.
Parameters
----------
filepath : Path
The path to the file to store in the cache.
cache_filename : str, optional
The name to use in the cache. If None, uses the original filename.
By default, None.
overwrite : bool, optional
Whether or not to overwrite an existing file with the same name.
By default False, will keep the older version.
delete_source : bool, optional
Whether or not to delete the source file after storing.
By default, False.
Returns
-------
Path
The new path of the file in the cache.
"""
if cache_filename is None:
cache_filename = filepath.name
new_file_path = get_cache_dir() / cache_filename
exists = new_file_path.exists()
if not overwrite and exists:
if delete_source:
filepath.unlink()
return new_file_path
# otherwise we write the file
get_cache_dir().mkdir(parents=True, exist_ok=True)
shutil.copy(filepath, new_file_path)
if delete_source:
filepath.unlink()
return new_file_path
[docs]
def store(filepath: Path, *, overwrite: bool = False, delete_source: bool = False) -> Path:
"""
Store an engine file in the trtutils engine cache.
Parameters
----------
filepath : Path
The path to the engine file to store in the cache.
overwrite : bool, optional
Whether or not to overwrite an existing file with the same name.
By default False, will keep the older version.
delete_source : bool, optional
Whether or not to delete the source file after storing.
By default, False.
Returns
-------
Path
The new path of the file in the cache.
"""
return store_file(filepath, overwrite=overwrite, delete_source=delete_source)
[docs]
def remove_file(filename: str, extension: str = "engine") -> None:
"""
Remove a file from the cache.
Parameters
----------
filename : str
The filename to remove from the cache. Can be with or without extension.
If extension is provided in filename, it will be used.
extension : str, optional
The file extension to use (without the dot).
By default, "engine".
Raises
------
FileNotFoundError
If the file does not exist in the cache.
"""
file_path = _get_cache_file_path(filename, extension)
if not file_path.exists():
err_msg = f"File {file_path} does not exist in the cache"
raise FileNotFoundError(err_msg)
file_path.unlink()
[docs]
def remove(filename: str) -> None:
"""
Remove an engine file from the cache.
Parameters
----------
filename : str
The filename to remove from the cache.
"""
remove_file(filename, extension="engine")
[docs]
def query_timing_cache() -> tuple[bool, Path]:
"""
Query the cache for the global timing cache.
Returns
-------
tuple[bool, Path]
Whether or not the global timing cache exists and its Path.
"""
return query_file("global", extension="cache")
[docs]
def store_timing_cache(
filepath: Path, *, overwrite: bool = False, delete_source: bool = False
) -> Path:
"""
Store the global timing cache in the cache directory.
Parameters
----------
filepath : Path
The path to the timing cache file to store.
overwrite : bool, optional
Whether or not to overwrite an existing global timing cache.
By default False, will keep the older version.
delete_source : bool, optional
Whether or not to delete the source file after storing.
By default, False.
Returns
-------
Path
The path of the global timing cache in the cache directory.
"""
return store_file(
filepath, cache_filename="global.cache", overwrite=overwrite, delete_source=delete_source
)
[docs]
def save_timing_cache_to_global(timing_cache_obj: _TimingCache, *, overwrite: bool = True) -> Path:
"""
Save a TensorRT timing cache object to the global timing cache.
Parameters
----------
timing_cache_obj
The TensorRT timing cache object (from config.get_timing_cache()).
overwrite : bool, optional
Whether or not to overwrite an existing global timing cache.
By default True.
Returns
-------
Path
The path of the global timing cache in the cache directory.
"""
serialized_cache = memoryview(timing_cache_obj.serialize())
# create a temporary file to store the serialized cache
with tempfile.NamedTemporaryFile(delete=False, suffix=".cache") as tmp_file:
tmp_path = Path(tmp_file.name)
tmp_path.write_bytes(serialized_cache)
return store_timing_cache(tmp_path, overwrite=overwrite, delete_source=True)