Source code for trtutils.builder._progress
# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations
from typing import TYPE_CHECKING
from tqdm import tqdm
from trtutils.compat._libs import trt
if TYPE_CHECKING:
from typing_extensions import Self
[docs]
class ProgressBar(trt.IProgressMonitor):
"""A progress bar for building TensorRT engines."""
def __init__(self: Self) -> None:
"""Initialize the progress bar."""
super().__init__()
self._progress_bars: dict[str, tqdm] = {}
self._phase_parents: dict[str, str | None] = {} # Track parent relationships
self._indentation_levels: dict[
str,
int,
] = {} # Track indentation levels directly
self._last_steps: dict[str, int] = {} # Track last step for each phase
self._interrupted: bool = False
def __del__(self: Self) -> None:
for progress_bar in self._progress_bars.values():
progress_bar.close()
[docs]
def phase_start(
self: Self,
phase_name: str,
parent_phase: str | None,
num_steps: int,
) -> None:
"""
Start a new phase.
Parameters
----------
phase_name : str
The name of the phase.
parent_phase : str | None
The name of the parent phase, or None if the phase is a root phase.
num_steps : int
The number of steps in the phase.
"""
try:
# Store parent relationship
self._phase_parents[phase_name] = parent_phase
# Calculate indentation based on parent
current_indent = 0
if parent_phase is not None and parent_phase in self._indentation_levels:
current_indent = self._indentation_levels[parent_phase] + 1
self._indentation_levels[phase_name] = current_indent
self._last_steps[phase_name] = 0 # Initialize last step counter
# Create progress bar with indentation
indent = " " * current_indent
desc = f"{indent}{phase_name}"
self._progress_bars[phase_name] = tqdm(
total=num_steps,
desc=desc,
leave=False,
position=len(self._progress_bars),
bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]",
)
except KeyboardInterrupt:
self._interrupted = True
[docs]
def step_complete(self: Self, phase_name: str, step: int) -> bool:
"""
Step in current phase is completed.
Parameters
----------
phase_name : str
The name of the phase.
step : int
The step number.
Returns
-------
bool
True if the build should continue, False if it should be interrupted.
"""
try:
if phase_name in self._progress_bars:
last_step = self._last_steps[phase_name]
step_diff = step - last_step
if step_diff > 0:
self._progress_bars[phase_name].update(step_diff)
self._last_steps[phase_name] = step
except KeyboardInterrupt:
self._interrupted = True
return False
return not self._interrupted
[docs]
def phase_finish(self: Self, phase_name: str) -> None:
"""Finish the current phase."""
try:
if phase_name in self._progress_bars:
self._progress_bars[phase_name].close()
del self._progress_bars[phase_name]
if phase_name in self._phase_parents:
del self._phase_parents[phase_name]
if phase_name in self._indentation_levels:
del self._indentation_levels[phase_name]
if phase_name in self._last_steps:
del self._last_steps[phase_name]
except KeyboardInterrupt:
self._interrupted = True