Remove FX backward test from GitHub actions runs for now.

pull/1007/head
Ross Wightman 3 years ago
parent 878bee1d5e
commit 147e1059a8

@ -386,11 +386,14 @@ def test_model_forward_fx(model_name, batch_size):
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
if 'GITHUB_ACTIONS' not in os.environ:
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
@ -418,6 +421,7 @@ def test_model_backward_fx(model_name, batch_size):
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224',

Loading…
Cancel
Save