GitHub - lengstrom/tensorguard


tensorguard

Pretty runtime typechecking for PyTorch and NumPy tensors!

Install

git clone git@github.com:lengstrom/tensorguard.git
pip install -e tensorguard

Example usage

As a decorator:

from tensorguard import tensorguard, Tensor as T
import torch as ch

@tensorguard
def inference(x: T(['bs', 3, 224, 224], 'float16', 'cpu'), y: T(['bs'], 'int64')):
    pass

# make examples with wrong dtype
x = ch.ones(128, 3, 224, 224, dtype=ch.float32)
# make labels with wrong batch size
y = ch.ones(256, dtype=ch.int64)

# @tensorguard checks the arguments at call time, so both mismatches are caught here
inference(x, y)
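To show the general idea behind a decorator like @tensorguard, here is a minimal sketch of annotation-driven shape checking. It is not tensorguard's implementation: `Spec` and `check_shapes` are hypothetical names, tensors are duck-typed through `.shape` and `.dtype`, and device checks are left out. Integer dims must match exactly, None accepts any size, and a named dim such as 'bs' is bound by the first argument that uses it and must agree across every other argument in the same call.

```python
import functools
import inspect
from types import SimpleNamespace
from typing import Any, Dict, Optional, Sequence, Union

# A dimension is an exact size (int), a name bound per call (str), or a wildcard (None).
Dim = Union[int, str, None]


class Spec:
    """Hypothetical tensor annotation: expected dims and an optional dtype name."""

    def __init__(self, dims: Optional[Sequence[Dim]] = None, dtype: Optional[str] = None):
        self.dims = None if dims is None else list(dims)
        self.dtype = dtype


def _check(name: str, value: Any, spec: Spec, bound: Dict[str, int]) -> None:
    if spec.dtype is not None:
        # Keep only the last component, so "torch.float32" and "float32" compare equal.
        actual = str(value.dtype).split(".")[-1]
        if actual != spec.dtype:
            raise TypeError(f"{name}: expected dtype {spec.dtype}, got {actual}")
    if spec.dims is None:
        return
    shape = tuple(value.shape)
    if len(shape) != len(spec.dims):
        raise TypeError(f"{name}: expected {len(spec.dims)} dims, got shape {shape}")
    for i, (want, got) in enumerate(zip(spec.dims, shape)):
        if want is None:
            continue  # wildcard: any size is accepted
        if isinstance(want, str):
            # Named dim: the first argument that uses it fixes its size for this call.
            if bound.setdefault(want, got) != got:
                raise TypeError(
                    f"{name}: dim {i} ('{want}') is {got}, but '{want}' was bound to {bound[want]}"
                )
        elif want != got:
            raise TypeError(f"{name}: dim {i} expected {want}, got {got}")


def check_shapes(fn):
    """Hypothetical decorator: validate every Spec-annotated argument before calling fn."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        call = sig.bind(*args, **kwargs)
        bound_dims: Dict[str, int] = {}  # fresh per call, shared across arguments
        for name, value in call.arguments.items():
            spec = sig.parameters[name].annotation
            if isinstance(spec, Spec):
                _check(name, value, spec, bound_dims)
        return fn(*args, **kwargs)

    return wrapper


# Stand-ins exposing .shape and .dtype; torch tensors or NumPy arrays should behave the same.
@check_shapes
def inference(x: Spec(["bs", 3, 4, 4], "float16"), y: Spec(["bs"], "int64")):
    return "ok"


x = SimpleNamespace(shape=(8, 3, 4, 4), dtype="float16")
y = SimpleNamespace(shape=(16,), dtype="int64")  # batch size disagrees with x

try:
    inference(x, y)
    error = None
except TypeError as exc:
    error = str(exc)

print(error)  # y: dim 0 ('bs') is 16, but 'bs' was bound to 8
```

The per-call `bound_dims` dictionary is what ties 'bs' together across `x` and `y`; creating it inside `wrapper` keeps one call's bindings from leaking into the next.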

As a standalone assertion:

import torch as ch
from tensorguard import tensorcheck, Tensor

x = ch.randn(4, 4, dtype=ch.float32)
x_expected = Tensor([4, 4])
y = ch.ones(4, dtype=ch.int64)
y_expected = Tensor([4], 'int64')

# check one tensor at a time
tensorcheck(x, x_expected)

# or several at once
tensorcheck([x, y], [x_expected, y_expected])

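Leaving a field unspecified, or setting it to None, makes it a wildcard that matches anything; by default, every field is None. The same applies to individual dimensions of a shape. You can also check whether a tensor comes from 'numpy' or 'pytorch' with the library field: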

tensorcheck(x, Tensor([4, None], library='numpy', device=None))
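The wildcard rule can be illustrated with a small, self-contained sketch. This is an assumption-laden illustration, not tensorguard's code: `TensorSpec` and `spec_matches` are hypothetical names, and the observed tensor is described by a plain record rather than inspected directly. A None field skips that check entirely, and a None entry inside `shape` accepts any size for that one dimension. The 'numpy' and 'pytorch' strings simply mirror the values mentioned above.

```python
from dataclasses import dataclass, fields
from typing import Optional, Sequence


@dataclass
class TensorSpec:
    # Hypothetical spec: every field defaults to None, meaning "do not check".
    shape: Optional[Sequence[Optional[int]]] = None
    dtype: Optional[str] = None
    device: Optional[str] = None
    library: Optional[str] = None


def spec_matches(expected: TensorSpec, observed: TensorSpec) -> bool:
    """Return True if every non-None field of `expected` agrees with `observed`."""
    for f in fields(TensorSpec):
        want = getattr(expected, f.name)
        got = getattr(observed, f.name)
        if want is None:
            continue  # field-level wildcard
        if f.name == "shape":
            if got is None or len(want) != len(got):
                return False
            # dimension-level wildcard: a None entry accepts any size
            if any(w is not None and w != g for w, g in zip(want, got)):
                return False
        elif want != got:
            return False
    return True


observed = TensorSpec(shape=(4, 7), dtype="float32", device="cpu", library="numpy")
print(spec_matches(TensorSpec([4, None], library="numpy"), observed))    # True
print(spec_matches(TensorSpec([4, None], library="pytorch"), observed))  # False
print(spec_matches(TensorSpec(), observed))                              # True
```

Because every field defaults to None, an empty `TensorSpec()` matches any tensor, while the number of dimensions is still enforced whenever a shape is given.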

Related work

TODOs:

  • use a different color for each individual error found in the runtime type checking