Source code for trtutils._nvtx

# Copyright (c) 2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
from __future__ import annotations

from typing import TYPE_CHECKING

import nvtx

from ._flags import FLAGS

if TYPE_CHECKING:
    from types import TracebackType

    from typing_extensions import Self


[docs] def enable_nvtx() -> None: """Enable trtutils NVTX profiling.""" FLAGS.NVTX_ENABLED = True
[docs] def disable_nvtx() -> None: """Disable trtutils NVTX profiling.""" FLAGS.NVTX_ENABLED = False
[docs] class NVTX: """Context manager and static helpers for trtutils NVTX profiling.""" def __init__(self: Self, name: str) -> None: """ Initialize trtutils NVTX context manager. Parameters ---------- name : str The name of the NVTX scope. """ self._name = name self._pre_enabled = False def __enter__(self: Self) -> None: """Enter the NVTX context manager.""" if FLAGS.NVTX_ENABLED: self._pre_enabled = True else: enable_nvtx() nvtx.push_range(self._name) def __exit__( self: Self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: """Exit the NVTX context manager.""" if not self._pre_enabled: disable_nvtx() nvtx.pop_range()