Source code for trtutils.core._graph

# Copyright (c) 2025-2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING

import nvtx

from trtutils._flags import FLAGS
from trtutils._log import LOG
from trtutils.compat._libs import cudart

if TYPE_CHECKING:
    from types import TracebackType

from ._cuda import cuda_call

if TYPE_CHECKING:
    from typing_extensions import Self


[docs] def cuda_stream_begin_capture( stream: cudart.cudaStream_t, mode: cudart.cudaStreamCaptureMode | None = None, ) -> None: """ Begin capturing a CUDA graph on the given stream. Parameters ---------- stream : cudart.cudaStream_t The CUDA stream to begin capture on. mode : cudart.cudaStreamCaptureMode, optional The capture mode to use. Default is ThreadLocal, which only checks CUDA calls from the capturing thread. Global mode would cause any uncapturable call in any thread to fail during capture. """ if mode is None: mode = cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal cuda_call(cudart.cudaStreamBeginCapture(stream, mode))
[docs] def cuda_stream_end_capture(stream: cudart.cudaStream_t) -> cudart.cudaGraph_t: """ End capturing a CUDA graph and return the captured graph. Parameters ---------- stream : cudart.cudaStream_t The CUDA stream to end capture on. Returns ------- cudart.cudaGraph_t The captured CUDA graph. """ return cuda_call(cudart.cudaStreamEndCapture(stream))
[docs] def cuda_graph_instantiate( graph: cudart.cudaGraph_t, flags: int = 0, ) -> cudart.cudaGraphExec_t: """ Instantiate a CUDA graph executable. Parameters ---------- graph : cudart.cudaGraph_t The CUDA graph to instantiate. flags : int, optional Flags for graph instantiation. Default is 0. Returns ------- cudart.cudaGraphExec_t The instantiated graph executable. """ if FLAGS.CUDA_11: return cuda_call(cudart.cudaGraphInstantiate(graph, b"", 0)) return cuda_call(cudart.cudaGraphInstantiate(graph, flags))
[docs] def cuda_graph_launch( graph_exec: cudart.cudaGraphExec_t, stream: cudart.cudaStream_t, ) -> None: """ Launch a CUDA graph executable. Parameters ---------- graph_exec : cudart.cudaGraphExec_t The graph executable to launch. stream : cudart.cudaStream_t The CUDA stream to launch on. """ cuda_call(cudart.cudaGraphLaunch(graph_exec, stream))
[docs] def cuda_graph_destroy(graph: cudart.cudaGraph_t) -> None: """ Destroy a CUDA graph. Parameters ---------- graph : cudart.cudaGraph_t The CUDA graph to destroy. """ cuda_call(cudart.cudaGraphDestroy(graph))
[docs] def cuda_graph_exec_destroy(graph_exec: cudart.cudaGraphExec_t) -> None: """ Destroy a CUDA graph executable. Parameters ---------- graph_exec : cudart.cudaGraphExec_t The graph executable to destroy. """ cuda_call(cudart.cudaGraphExecDestroy(graph_exec))
[docs] class CUDAGraph: """Wrapper around CUDA graph capture and execution.""" def __init__(self: Self, stream: cudart.cudaStream_t) -> None: """ Initialize the CUDA graph helper. Parameters ---------- stream : cudart.cudaStream_t The CUDA stream to use for graph operations. """ self._stream = stream self._graph: cudart.cudaGraph_t | None = None self._graph_exec: cudart.cudaGraphExec_t | None = None def __enter__(self: Self) -> Self: self.start() return self def __exit__( self: Self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool: success = self.stop() # __exit__ returns False on success, kind of funky return not success def __del__(self: Self) -> None: self.invalidate()
[docs] def start(self: Self) -> None: """ Begin graph capture. This should be called before the operations to capture. """ if FLAGS.NVTX_ENABLED: nvtx.push_range("cuda_graph::start") cuda_stream_begin_capture(self._stream) if FLAGS.NVTX_ENABLED: nvtx.pop_range()
[docs] def stop(self: Self) -> bool: """ End graph capture and instantiate the graph. Returns ------- bool True if capture and instantiation succeeded, False otherwise. """ if FLAGS.NVTX_ENABLED: nvtx.push_range("cuda_graph::stop") try: self._graph = cuda_stream_end_capture(self._stream) self._graph_exec = cuda_graph_instantiate(self._graph, 0) except RuntimeError as e: err_str = str(e) if "cudaErrorStreamCapture" in err_str or "StreamCapture" in err_str: LOG.warning( f"CUDA graph capture failed (engine may not support graphs): {err_str}", ) else: LOG.warning(f"CUDA graph capture failed: {err_str}") self.invalidate() if FLAGS.NVTX_ENABLED: nvtx.pop_range() return False else: if FLAGS.NVTX_ENABLED: nvtx.pop_range() return True
[docs] def launch(self: Self) -> None: """ Launch the captured graph. Raises ------ RuntimeError If no graph has been captured. """ if FLAGS.NVTX_ENABLED: nvtx.push_range("cuda_graph::launch") if self._graph_exec is None: err_msg = "Cannot launch graph: no graph has been captured" if FLAGS.NVTX_ENABLED: nvtx.pop_range() raise RuntimeError(err_msg) cuda_graph_launch(self._graph_exec, self._stream) if FLAGS.NVTX_ENABLED: nvtx.pop_range()
[docs] def invalidate(self: Self) -> None: """Destroy the graph and graph executable, resetting state.""" with contextlib.suppress(AttributeError, RuntimeError): cuda_graph_exec_destroy(self._graph_exec) self._graph_exec = None with contextlib.suppress(AttributeError, RuntimeError): cuda_graph_destroy(self._graph) self._graph = None
@property def is_captured(self: Self) -> bool: """ Check if a graph has been captured. Returns ------- bool True if a graph has been captured, False otherwise. """ return self._graph_exec is not None