diff --git a/tests/test_models.py b/tests/test_models.py index c3642eb9..9fb826c5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -399,8 +399,10 @@ def test_model_backward_fx(model_name, batch_size): pytest.skip("Fixed input size model > limit.") model = create_model(model_name, pretrained=False, num_classes=42) - num_params = sum([x.numel() for x in model.parameters()]) model.train() + num_params = sum([x.numel() for x in model.parameters()]) + if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6: + pytest.skip("Skipping FX backward test on model with more than 100M params.") model = _create_fx_model(model, train=True) outputs = tuple(model(torch.randn((batch_size, *input_size))).values())