From 69d725c9fe92efed62e954d31e95653fb091a797 Mon Sep 17 00:00:00 2001 From: michal Date: Thu, 7 May 2020 00:20:58 -0400 Subject: [PATCH] Basic forward pass test for all registered models --- tests/__init__.py | 0 tests/test_inference.py | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/test_inference.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_inference.py b/tests/test_inference.py new file mode 100644 index 00000000..75b8d445 --- /dev/null +++ b/tests/test_inference.py @@ -0,0 +1,19 @@ +import pytest +import torch + +from timm import list_models, create_model + + +@pytest.mark.timeout(60) +@pytest.mark.parametrize('model_name', list_models()) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward(model_name, batch_size): + """Run a single forward pass with each model""" + model = create_model(model_name, pretrained=False) + model.eval() + + inputs = torch.randn((batch_size, *model.default_cfg['input_size'])) + outputs = model(inputs) + + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs'