diff --git a/tests/test_models.py b/tests/test_models.py index 1750d540..5fde43da 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -334,8 +334,14 @@ def _create_fx_model(model, train=False): return fx_model +EXCLUDE_FX_FILTERS = [] +# not enough memory to run fx on more models than other tests +if 'GITHUB_ACTIONS' in os.environ: + EXCLUDE_FX_FILTERS += ['beit_large*', 'swin_large*'] + + @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_fx(model_name, batch_size): """ @@ -367,7 +373,8 @@ def test_model_forward_fx(model_name, batch_size): @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) +@pytest.mark.parametrize('model_name', list_models( + exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [2]) def test_model_backward_fx(model_name, batch_size): """Symbolically trace each model and run single backward pass through the resulting GraphModule""" @@ -400,7 +407,7 @@ EXCLUDE_FX_JIT_FILTERS = [ 'deit_*_distilled_patch16_224', 'levit*', 'pit_*_distilled_224', -] +] + EXCLUDE_FX_FILTERS @pytest.mark.timeout(120)