|
|
@ -348,6 +348,7 @@ if 'GITHUB_ACTIONS' in os.environ:
|
|
|
|
'vgg*',
|
|
|
|
'vgg*',
|
|
|
|
'vit_large*',
|
|
|
|
'vit_large*',
|
|
|
|
'xcit_large*',
|
|
|
|
'xcit_large*',
|
|
|
|
|
|
|
|
'mixer_l*',
|
|
|
|
]
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -368,15 +369,16 @@ def test_model_forward_fx(model_name, batch_size):
|
|
|
|
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
|
|
|
|
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
|
|
|
|
if max(input_size) > MAX_FWD_FX_SIZE:
|
|
|
|
if max(input_size) > MAX_FWD_FX_SIZE:
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
inputs = torch.randn((batch_size, *input_size))
|
|
|
|
with torch.no_grad():
|
|
|
|
outputs = model(inputs)
|
|
|
|
inputs = torch.randn((batch_size, *input_size))
|
|
|
|
if isinstance(outputs, tuple):
|
|
|
|
outputs = model(inputs)
|
|
|
|
outputs = torch.cat(outputs)
|
|
|
|
if isinstance(outputs, tuple):
|
|
|
|
|
|
|
|
outputs = torch.cat(outputs)
|
|
|
|
|
|
|
|
|
|
|
|
model = _create_fx_model(model)
|
|
|
|
model = _create_fx_model(model)
|
|
|
|
fx_outputs = tuple(model(inputs).values())
|
|
|
|
fx_outputs = tuple(model(inputs).values())
|
|
|
|
if isinstance(fx_outputs, tuple):
|
|
|
|
if isinstance(fx_outputs, tuple):
|
|
|
|
fx_outputs = torch.cat(fx_outputs)
|
|
|
|
fx_outputs = torch.cat(fx_outputs)
|
|
|
|
|
|
|
|
|
|
|
|
assert torch.all(fx_outputs == outputs)
|
|
|
|
assert torch.all(fx_outputs == outputs)
|
|
|
|
assert outputs.shape[0] == batch_size
|
|
|
|
assert outputs.shape[0] == batch_size
|
|
|
@ -440,9 +442,10 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
|
|
|
|
model.eval()
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
model = torch.jit.script(_create_fx_model(model))
|
|
|
|
model = torch.jit.script(_create_fx_model(model))
|
|
|
|
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
|
|
|
with torch.no_grad():
|
|
|
|
if isinstance(outputs, tuple):
|
|
|
|
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
|
|
|
outputs = torch.cat(outputs)
|
|
|
|
if isinstance(outputs, tuple):
|
|
|
|
|
|
|
|
outputs = torch.cat(outputs)
|
|
|
|
|
|
|
|
|
|
|
|
assert outputs.shape[0] == batch_size
|
|
|
|
assert outputs.shape[0] == batch_size
|
|
|
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
|
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
|
|