# Source code for trtutils.core._stream

# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations

import contextlib

import nvtx

with contextlib.suppress(Exception):
    from trtutils.compat._libs import cudart

from trtutils._flags import FLAGS

from ._cuda import cuda_call


def create_stream() -> cudart.cudaStream_t:
    """
    Create a CUDA Stream.

    When ``FLAGS.NVTX_ENABLED`` is set, the call is wrapped in the NVTX
    range ``"core::create_stream"``.

    Returns
    -------
    cudart.cudaStream_t
        The CUDA stream.

    """
    # Read the flag once so the push and pop below are always balanced,
    # even if the flag is toggled while the CUDA call is in flight.
    nvtx_enabled = FLAGS.NVTX_ENABLED
    if nvtx_enabled:
        nvtx.push_range("core::create_stream")
    try:
        return cuda_call(cudart.cudaStreamCreate())
    finally:
        # Pop in ``finally`` so an error raised by ``cuda_call`` does not
        # leave the NVTX range stack unbalanced.
        if nvtx_enabled:
            nvtx.pop_range()
def destroy_stream(stream: cudart.cudaStream_t) -> None:
    """
    Destroy a CUDA Stream.

    When ``FLAGS.NVTX_ENABLED`` is set, the call is wrapped in the NVTX
    range ``"core::destroy_stream"``, matching the other stream helpers.

    Parameters
    ----------
    stream : cudart.cudaStream_t
        The CUDA stream to destroy.

    """
    # Read the flag once so the push and pop below are always balanced.
    nvtx_enabled = FLAGS.NVTX_ENABLED
    if nvtx_enabled:
        nvtx.push_range("core::destroy_stream")
    try:
        cuda_call(cudart.cudaStreamDestroy(stream))
    finally:
        # Pop even if ``cuda_call`` raises, keeping the range stack balanced.
        if nvtx_enabled:
            nvtx.pop_range()
def stream_synchronize(stream: cudart.cudaStream_t) -> None:
    """
    Synchronize a CUDA Stream with error checking.

    Calls ``cudaStreamSynchronize`` on the given stream via ``cuda_call``.
    When ``FLAGS.NVTX_ENABLED`` is set, the call is wrapped in the NVTX
    range ``"core::stream_synchronize"``.

    Parameters
    ----------
    stream : cudart.cudaStream_t
        The stream to synchronize calls for.

    """
    # Read the flag once so the push and pop below are always balanced,
    # even if the flag is toggled while the CUDA call is in flight.
    nvtx_enabled = FLAGS.NVTX_ENABLED
    if nvtx_enabled:
        nvtx.push_range("core::stream_synchronize")
    try:
        cuda_call(cudart.cudaStreamSynchronize(stream))
    finally:
        # Pop in ``finally`` so an error raised by ``cuda_call`` does not
        # leave the NVTX range stack unbalanced.
        if nvtx_enabled:
            nvtx.pop_range()