diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_inference.py b/tests/test_inference.py new file mode 100644 index 00000000..75b8d445 --- /dev/null +++ b/tests/test_inference.py @@ -0,0 +1,19 @@ +import pytest +import torch + +from timm import list_models, create_model + + +@pytest.mark.timeout(60) +@pytest.mark.parametrize('model_name', list_models()) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward(model_name, batch_size): + """Run a single forward pass with each model""" + model = create_model(model_name, pretrained=False) + model.eval() + + inputs = torch.randn((batch_size, *model.default_cfg['input_size'])) + outputs = model(inputs) + + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs'