New FX test strategy, filter based on param count

pull/842/head
Ross Wightman 3 years ago
parent 1e51c2d02e
commit ce76a810c2

@@ -399,8 +399,10 @@ def test_model_backward_fx(model_name, batch_size):
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42)
num_params = sum([x.numel() for x in model.parameters()])
model.train()
num_params = sum([x.numel() for x in model.parameters()])
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
pytest.skip("Skipping FX backward test on model with more than 100M params.")
model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())

Loading…
Cancel
Save