From 1e51c2d02e77373e4e4248f5d826f519379ebdff Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 22 Nov 2021 09:46:43 -0800 Subject: [PATCH] More FX test tweaks --- tests/test_models.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 68939a14..c3642eb9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -348,6 +348,7 @@ if 'GITHUB_ACTIONS' in os.environ: 'vgg*', 'vit_large*', 'xcit_large*', + 'mixer_l*', ] @@ -368,15 +369,16 @@ def test_model_forward_fx(model_name, batch_size): input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE) if max(input_size) > MAX_FWD_FX_SIZE: pytest.skip("Fixed input size model > limit.") - inputs = torch.randn((batch_size, *input_size)) - outputs = model(inputs) - if isinstance(outputs, tuple): - outputs = torch.cat(outputs) + with torch.no_grad(): + inputs = torch.randn((batch_size, *input_size)) + outputs = model(inputs) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) - model = _create_fx_model(model) - fx_outputs = tuple(model(inputs).values()) - if isinstance(fx_outputs, tuple): - fx_outputs = torch.cat(fx_outputs) + model = _create_fx_model(model) + fx_outputs = tuple(model(inputs).values()) + if isinstance(fx_outputs, tuple): + fx_outputs = torch.cat(fx_outputs) assert torch.all(fx_outputs == outputs) assert outputs.shape[0] == batch_size @@ -440,9 +442,10 @@ def test_model_forward_fx_torchscript(model_name, batch_size): model.eval() model = torch.jit.script(_create_fx_model(model)) - outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) - if isinstance(outputs, tuple): - outputs = torch.cat(outputs) + with torch.no_grad(): + outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) assert outputs.shape[0] == batch_size assert not torch.isnan(outputs).any(), 'Output included NaNs'