Source code for trtutils.core._device
# Copyright (c) 2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
from __future__ import annotations
from typing import TYPE_CHECKING
from trtutils.compat._libs import cudart
from ._cuda import cuda_call
if TYPE_CHECKING:
from typing_extensions import Self
_SM_ARCH_MAP: dict[int, str | dict[int, str]] = {
3: "kepler",
5: "maxwell",
6: "pascal",
7: {0: "volta", 1: "volta", 2: "volta", 5: "turing"},
8: {0: "ampere", 6: "ampere", 7: "ampere", 9: "ada"},
9: "hopper",
10: "blackwell",
12: "blackwell",
}
[docs]
def get_sm_arch(major: int, minor: int) -> str:
    """
    Resolve a compute capability version to its GPU architecture name.

    Parameters
    ----------
    major : int
        The major compute capability version.
    minor : int
        The minor compute capability version. Only consulted when the
        major version maps to more than one architecture.

    Returns
    -------
    str
        The architecture name (e.g. "turing", "blackwell"), or "unknown"
        when the major version is 0, the major version has no table entry,
        or the minor version is not listed for that major version.

    """
    # A major version of 0 is reported as unknown without touching the table.
    if major == 0:
        return "unknown"

    # Missing majors fall back to the same "unknown" string.
    arch = _SM_ARCH_MAP.get(major, "unknown")

    # String entries cover the whole major version; dict entries are keyed
    # by the minor version.
    return arch if isinstance(arch, str) else arch.get(minor, "unknown")
[docs]
def get_device_name(device: int = 0) -> str:
    """
    Return the name reported for a CUDA device.

    Parameters
    ----------
    device : int, optional
        The CUDA device index. Default is 0.

    Returns
    -------
    str
        The device name (e.g. "NVIDIA GeForce RTX 5080").

    """
    raw_name = cuda_call(cudart.cudaGetDeviceProperties(device)).name
    # The binding may hand the name back as bytes; normalize to str.
    if isinstance(raw_name, bytes):
        return raw_name.decode()
    return raw_name
[docs]
def get_compute_capability(device: int = 0) -> tuple[int, int]:
    """
    Return the compute capability (SM version) of a CUDA device.

    Parameters
    ----------
    device : int, optional
        The CUDA device index. Default is 0.

    Returns
    -------
    tuple[int, int]
        The (major, minor) compute capability version, read from the
        device properties.

    """
    device_props = cuda_call(cudart.cudaGetDeviceProperties(device))
    major, minor = device_props.major, device_props.minor
    return (major, minor)
[docs]
def get_device() -> int:
    """
    Return the index of the current CUDA device.

    Returns
    -------
    int
        The current CUDA device index, as reported by ``cudaGetDevice``.

    """
    current_device = cuda_call(cudart.cudaGetDevice())
    return current_device
[docs]
def set_device(device: int) -> None:
    """
    Make ``device`` the current CUDA device.

    Parameters
    ----------
    device : int
        The CUDA device index to set.

    """
    result = cudart.cudaSetDevice(device)
    # Route the raw runtime result through the shared cuda_call helper.
    cuda_call(result)
[docs]
def get_device_count() -> int:
    """
    Return how many CUDA devices are available.

    Returns
    -------
    int
        The number of CUDA devices, as reported by ``cudaGetDeviceCount``.

    """
    device_count = cuda_call(cudart.cudaGetDeviceCount())
    return device_count
[docs]
class Device:
    """
    Context manager that switches to a CUDA device and restores the old one.

    On entry the current device is recorded and ``device`` is made current;
    on exit the recorded device is made current again.

    When ``device`` is ``None`` the guard does nothing: ``__enter__`` and
    ``__exit__`` only inspect one attribute, so the cost on the hot path
    is negligible.

    Instances are **reusable** — engines keep one as ``self._device_guard``
    and enter/exit it on every ``execute()`` call. Only one previous device
    is stored per instance, so entering the same instance again before
    exiting overwrites the saved value.
    """

    __slots__ = ("_device", "_previous")

    def __init__(self, device: int | None) -> None:
        self._device = device
        # Placeholder until __enter__ records the real previous device.
        self._previous: int = 0

    def __enter__(self: Self) -> Self:
        target = self._device
        if target is None:
            # No-op guard: leave the current device untouched.
            return self
        # Save first, then switch, so __exit__ can undo the change.
        self._previous = get_device()
        set_device(target)
        return self

    def __exit__(self, *args: object) -> None:
        if self._device is None:
            return
        set_device(self._previous)