More FX test tweaks

pull/842/head
Ross Wightman 3 years ago
parent f0507f6da6
commit 1e51c2d02e

@@ -348,6 +348,7 @@ if 'GITHUB_ACTIONS' in os.environ:
         'vgg*',
         'vit_large*',
         'xcit_large*',
+        'mixer_l*',
     ]
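
For context, a minimal sketch of how a wildcard entry like 'mixer_l*' takes effect: timm's list_models applies exclusion filters as fnmatch-style shell patterns, so the large MLP-Mixer variants get skipped on the memory-constrained GitHub Actions runners. The helper below is illustrative only and is not part of the test file.

# Hypothetical helper, not in test_models.py: exclusion filters such as
# 'mixer_l*' are shell-style wildcards matched against each model name.
import fnmatch

EXCLUDE_FX_FILTERS = ['vgg*', 'vit_large*', 'xcit_large*', 'mixer_l*']

def is_excluded(model_name: str) -> bool:
    # True if the name matches any exclusion pattern and should be skipped.
    return any(fnmatch.fnmatch(model_name, pat) for pat in EXCLUDE_FX_FILTERS)

assert is_excluded('mixer_l16_224')       # newly excluded on GitHub Actions
assert not is_excluded('mixer_b16_224')   # smaller Mixer variants still run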
@@ -368,6 +369,7 @@ def test_model_forward_fx(model_name, batch_size):
     input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
     if max(input_size) > MAX_FWD_FX_SIZE:
         pytest.skip("Fixed input size model > limit.")
+    with torch.no_grad():
         inputs = torch.randn((batch_size, *input_size))
         outputs = model(inputs)
         if isinstance(outputs, tuple):
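
The change above wraps the test's forward pass in torch.no_grad(), so no autograd graph is recorded and peak memory during CI stays lower. A minimal, self-contained sketch of that pattern, using a toy model rather than a timm model:

# Hedged sketch of the no_grad pattern added above; the Sequential model here
# stands in for the timm model created by the real test.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
model.eval()

with torch.no_grad():
    inputs = torch.randn(2, 8)
    outputs = model(inputs)

# No graph is retained, so the output carries no grad and intermediates are freed.
assert not outputs.requires_grad
assert outputs.shape == (2, 4)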
@@ -440,6 +442,7 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
     model.eval()
     model = torch.jit.script(_create_fx_model(model))
+    with torch.no_grad():
         outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
         if isinstance(outputs, tuple):
             outputs = torch.cat(outputs)
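
The torchscript variant gets the same treatment: the scripted model is run without gradients. A rough approximation of that structure, scripting a toy module instead of the FX GraphModule returned by _create_fx_model:

# Illustrative only: torch.jit.script a small module and run it under no_grad,
# mirroring the shape of the torchscript FX test above.
import torch
import torch.nn as nn

scripted = torch.jit.script(nn.Sequential(nn.Linear(8, 4), nn.ReLU()))
with torch.no_grad():
    out = scripted(torch.randn(1, 8))

assert not out.requires_grad
assert out.shape == (1, 4)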
