Source code for trtutils._log
# Copyright (c) 2024-2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations
import logging
import os
import sys
from typing import TYPE_CHECKING, TextIO
from trtutils.compat._libs import trt
if TYPE_CHECKING:
from types import TracebackType
from typing_extensions import Self
# Lookup table from a level name (upper-case) to the matching ``logging``
# constant. Insertion order is: the standard names, the "WARN" alias, and
# finally the ``None`` entry.
_LEVEL_MAP: dict[str | None, int] = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    # Short alias that resolves to the same value as "WARNING".
    "WARN": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
    # No level supplied at all -> WARNING.
    None: logging.WARNING,
}
def _setup_logger(level: str | None = None) -> None:
    """
    Configure the package logger and its first handler for the given level.

    Parameters
    ----------
    level : str | None
        Level name, matched case-insensitively against ``_LEVEL_MAP``.
        ``None`` or a name missing from the map resolves to WARNING.

    """
    key = None if level is None else level.upper()
    resolved = _LEVEL_MAP.get(key, logging.WARNING)

    pkg_logger = logging.getLogger(__package__)
    pkg_logger.setLevel(resolved)
    # Records are handled here only; they are not passed to ancestor loggers.
    pkg_logger.propagate = False

    if pkg_logger.handlers:
        # A handler is already attached: only retune the first one.
        pkg_logger.handlers[0].setLevel(resolved)
        return

    # First call: attach a stream handler, preferring stdout and falling back
    # to stderr when sys.stdout is None.
    stream: TextIO = sys.stderr if sys.stdout is None else sys.stdout
    handler = logging.StreamHandler(stream=stream)
    handler.setLevel(resolved)
    handler.setFormatter(
        logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"),
    )
    pkg_logger.addHandler(handler)
[docs]
def set_log_level(level: str) -> None:
    """
    Change the logging level used by the trtutils package.

    Parameters
    ----------
    level : str
        Name of the desired level, compared case-insensitively against
        "DEBUG", "INFO", "WARNING", "ERROR" and "CRITICAL".

    Raises
    ------
    ValueError
        If ``level`` is not one of the names listed above.

    """
    accepted = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
    if level.upper() in accepted:
        _setup_logger(level)
        return
    err_msg = f"Invalid log level: {level}"
    raise ValueError(err_msg)
# Import-time configuration: read the requested level from the
# TRTUTILS_LOG_LEVEL environment variable (None when unset) and apply it.
level = os.getenv("TRTUTILS_LOG_LEVEL")
_setup_logger(level)

_log = logging.getLogger(__name__)

# The logger has already been configured above; this only reports a value
# that is set but is not one of the five standard level names.
if level is not None:
    if level.upper() not in ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"):
        _log.warning(f"Invalid log level: {level}. Using default log level: WARNING")
# create a TensorRT compatible logger
class TRTLogger(trt.ILogger):
    """
    Logger that implements TensorRT's ILogger interface while using Python's logging system.

    This class bridges TensorRT's logging system with Python's standard logging module,
    allowing TensorRT log messages to be handled by the Python logging framework.
    It also provides convenience methods that match Python's standard logging levels.

    Examples
    --------
    >>> logger = TRTLogger()
    >>> logger.info("Starting TensorRT engine build")
    >>> with trt.Builder(logger) as builder:
    ...     # TensorRT will use the logger for its messages
    ...     pass

    """

    def __init__(self: Self) -> None:
        """
        Initialize the TensorRT logger.

        Creates a logger that implements TensorRT's ILogger interface and
        delegates to the ``"trtutils"`` Python logging.Logger internally.
        """
        super().__init__()
        self._logger = logging.getLogger("trtutils")
        # Level at construction time. Retained for backward compatibility;
        # the ``level`` property reads the live value instead.
        self._level = self._logger.getEffectiveLevel()

    @property
    def logger(self: Self) -> logging.Logger:
        """
        Get the internal Python logger.

        Returns
        -------
        logging.Logger
            The internal Python logger instance.

        """
        return self._logger

    @property
    def level(self: Self) -> int:
        """
        Get the current log level.

        Returns
        -------
        int
            The effective level of the internal logger at the time of the
            call, so later level changes are reflected.

        """
        return self._logger.getEffectiveLevel()

    class _LogLevelContext:
        """Context manager that temporarily overrides a TRTLogger's level."""

        def __init__(self: Self, logger: TRTLogger, level: str | None) -> None:
            """
            Store the target logger and the requested level name.

            Parameters
            ----------
            logger : TRTLogger
                The logger whose level is overridden inside the block.
            level : str | None
                Level name looked up (case-insensitively) in ``_LEVEL_MAP``.
                ``None``, an empty string or an unknown name leaves the
                level untouched.

            """
            self._logger = logger
            self._level = level
            # Fallback values so __exit__ is safe even without __enter__;
            # both are refreshed when the block is actually entered.
            self._old_level = logger.logger.level
            self._old_handler_levels: list[tuple[logging.Handler, int]] = []

        def __enter__(self: Self) -> TRTLogger:
            """
            Apply the requested level and return the wrapped TRTLogger.

            Returns
            -------
            TRTLogger
                The logger passed to the constructor.

            """
            py_logger = self._logger.logger
            # Capture state on entry (not construction) so a context object
            # created earlier restores what was in effect when it was used.
            # The explicit level is saved so a NOTSET logger stays NOTSET.
            self._old_level = py_logger.level
            self._old_handler_levels = []
            if self._level:
                log_level = _LEVEL_MAP.get(self._level.upper())
                if log_level is not None:
                    py_logger.setLevel(log_level)
                    # Handlers filter on their own level as well. Mirror
                    # _setup_logger, which keeps the first handler in step
                    # with the logger; otherwise lowering the level (e.g. to
                    # DEBUG) would emit nothing extra.
                    for handler in py_logger.handlers[:1]:
                        self._old_handler_levels.append((handler, handler.level))
                        handler.setLevel(log_level)
            return self._logger

        def __exit__(
            self: Self,
            exc_type: type[BaseException] | None,
            exc_val: BaseException | None,
            exc_tb: TracebackType | None,
        ) -> None:
            """Restore the logger and handler levels saved on entry."""
            for handler, old_level in self._old_handler_levels:
                handler.setLevel(old_level)
            self._old_handler_levels = []
            self._logger.logger.setLevel(self._old_level)

    def with_level(self: Self, level: str | None) -> _LogLevelContext:
        """
        Create a context manager to temporarily set the log level.

        Parameters
        ----------
        level : str | None
            The log level to set. One of "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL".
            ``None`` or an unrecognised name leaves the level unchanged.

        Returns
        -------
        _LogLevelContext
            A context manager that sets the log level for the duration of the block.

        """
        return self._LogLevelContext(self, level)

    def suppress(self: Self) -> _LogLevelContext:
        """
        Suppress log messages below CRITICAL.

        The returned context manager raises the level to CRITICAL for the
        duration of the block; CRITICAL records are still emitted.

        Returns
        -------
        _LogLevelContext
            A context manager that raises the log level to CRITICAL.

        """
        return self._LogLevelContext(self, "CRITICAL")

    def log(self: Self, severity: trt.ILogger.Severity, msg: str) -> None:
        """
        Log a message with the specified severity.

        This method implements TensorRT's ILogger.log method and maps TensorRT
        severity levels to Python logging levels: INTERNAL_ERROR -> critical,
        ERROR -> error, WARNING -> warning, INFO -> info, anything else
        (e.g. VERBOSE) -> debug.

        Parameters
        ----------
        severity : trt.ILogger.Severity
            TensorRT-specific severity level of the message
        msg : str
            The log message to record

        """
        severities = trt.ILogger.Severity
        # Looked up defensively: ``trt`` comes from a compatibility layer not
        # visible here, and a missing member must not break every log call.
        internal_error = getattr(severities, "INTERNAL_ERROR", None)
        if internal_error is not None and severity == internal_error:
            # Most severe TensorRT level; must not be demoted to debug.
            self._logger.critical(msg)
        elif severity == severities.ERROR:
            self._logger.error(msg)
        elif severity == severities.WARNING:
            self._logger.warning(msg)
        elif severity == severities.INFO:
            self._logger.info(msg)
        else:
            self._logger.debug(msg)

    def debug(self: Self, msg: str) -> None:
        """
        Log a debug message.

        Parameters
        ----------
        msg : str
            The debug message to log

        """
        self._logger.debug(msg)

    def info(self: Self, msg: str) -> None:
        """
        Log an info message.

        Parameters
        ----------
        msg : str
            The info message to log

        """
        self._logger.info(msg)

    def warning(self: Self, msg: str) -> None:
        """
        Log a warning message.

        Parameters
        ----------
        msg : str
            The warning message to log

        """
        self._logger.warning(msg)

    def error(self: Self, msg: str) -> None:
        """
        Log an error message.

        Parameters
        ----------
        msg : str
            The error message to log

        """
        self._logger.error(msg)

    def critical(self: Self, msg: str) -> None:
        """
        Log a critical message.

        Parameters
        ----------
        msg : str
            The critical message to log

        """
        self._logger.critical(msg)
# Module-level TRTLogger instance, created when this module is imported.
LOG = TRTLogger()