Source code for trtutils.inspect._tensor
# Copyright (c) 2026 Justin Davis (davisjustin302@gmail.com)
#
# MIT License
from __future__ import annotations
from trtutils.compat._libs import trt
[docs]
def get_tensor_size(tensor: trt.ITensor) -> int:
"""
Calculate the size of a tensor in bytes.
Computes the total memory footprint by multiplying the number of elements
(derived from the tensor shape) by the per-element byte size of the dtype.
Dynamic dimensions (``-1``) are treated as ``1``.
Parameters
----------
tensor : trt.ITensor
The TensorRT tensor.
Returns
-------
int
Size in bytes.
"""
shape = tensor.shape
# Handle dynamic dimensions by assuming 1
num_elements = 1
for dim in shape:
num_elements *= max(1, dim)
# Get dtype size
dtype = tensor.dtype
dtype_sizes = {
trt.DataType.FLOAT: 4,
trt.DataType.HALF: 2,
trt.DataType.INT8: 1,
trt.DataType.INT32: 4,
trt.DataType.BOOL: 1,
trt.DataType.UINT8: 1,
}
dtype_size = dtype_sizes.get(dtype, 4)
return num_elements * dtype_size