diff --git a/tests/test_inference.py b/tests/test_inference.py index 55bafb21..2490a0bc 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -8,12 +8,12 @@ from timm import list_models, create_model @pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*')) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): - """Run a single forward pass with each model""" - model = create_model(model_name, pretrained=False) - model.eval() + """Run a single forward pass with each model""" + model = create_model(model_name, pretrained=False) + model.eval() - inputs = torch.randn((batch_size, *model.default_cfg['input_size'])) - outputs = model(inputs) + inputs = torch.randn((batch_size, *model.default_cfg['input_size'])) + outputs = model(inputs) - assert outputs.shape[0] == batch_size - assert not torch.isnan(outputs).any(), 'Output included NaNs' + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs'