Source code for trtutils.core._stream
# Copyright (c) 2024 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
# mypy: disable-error-code="import-untyped"
from __future__ import annotations
import contextlib
import nvtx
with contextlib.suppress(Exception):
from trtutils.compat._libs import cudart
from trtutils._flags import FLAGS
from ._cuda import cuda_call
[docs]
def create_stream() -> cudart.cudaStream_t:
    """
    Create a CUDA Stream.

    The call to ``cudart.cudaStreamCreate`` is routed through ``cuda_call``
    and, when ``FLAGS.NVTX_ENABLED`` is set, wrapped in an NVTX range named
    ``core::create_stream``.

    Returns
    -------
    cudart.cudaStream_t
        The CUDA stream.

    """
    # Read the flag once so the push and the pop below always stay paired,
    # even if the flag is toggled while the CUDA call is in flight.
    nvtx_enabled = FLAGS.NVTX_ENABLED
    if nvtx_enabled:
        nvtx.push_range("core::create_stream")
    try:
        return cuda_call(cudart.cudaStreamCreate())
    finally:
        # Close the range even when the CUDA call raises; otherwise the
        # range pushed above would be left open.
        if nvtx_enabled:
            nvtx.pop_range()
[docs]
def destroy_stream(stream: cudart.cudaStream_t) -> None:
    """
    Destroy a CUDA Stream.

    Parameters
    ----------
    stream : cudart.cudaStream_t
        The CUDA stream to destroy.

    """
    # Issue the runtime call first, then hand its raw result to cuda_call.
    raw_result = cudart.cudaStreamDestroy(stream)
    cuda_call(raw_result)
[docs]
def stream_synchronize(stream: cudart.cudaStream_t) -> None:
    """
    Synchronize a CUDA stream.

    The call to ``cudart.cudaStreamSynchronize`` is routed through
    ``cuda_call`` and, when ``FLAGS.NVTX_ENABLED`` is set, wrapped in an
    NVTX range named ``core::stream_synchronize``.

    Parameters
    ----------
    stream : cudart.cudaStream_t
        The stream to synchronize calls for.

    """
    # Read the flag once so the push and the pop below always stay paired,
    # even if the flag is toggled while the synchronize call is blocking.
    nvtx_enabled = FLAGS.NVTX_ENABLED
    if nvtx_enabled:
        nvtx.push_range("core::stream_synchronize")
    try:
        cuda_call(cudart.cudaStreamSynchronize(stream))
    finally:
        # Close the range even when the CUDA call raises; otherwise the
        # range pushed above would be left open.
        if nvtx_enabled:
            nvtx.pop_range()