Merge pull request #1354 from rwightman/fix_tests

Attempting to fix unit test failures...
pull/1363/head
Ross Wightman 2 years ago committed by GitHub
commit 4547920f85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,6 +4,8 @@ import platform
import os import os
import fnmatch import fnmatch
_IS_MAC = platform.system() == 'Darwin'
try: try:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
has_fx_feature_extraction = True has_fx_feature_extraction = True
@ -322,7 +324,10 @@ def test_model_forward_features(model_name, batch_size):
assert not torch.isnan(o).any() assert not torch.isnan(o).any()
def _create_fx_model(model, train=False): if not _IS_MAC:
# MACOS test runners are really slow, only running tests below this point if not on a Darwin runner...
def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
@ -354,9 +359,9 @@ def _create_fx_model(model, train=False):
return fx_model return fx_model
EXCLUDE_FX_FILTERS = ['vit_gi*'] EXCLUDE_FX_FILTERS = ['vit_gi*']
# not enough memory to run fx on more models than other tests # not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ: if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [ EXCLUDE_FX_FILTERS += [
'beit_large*', 'beit_large*',
'mixer_l*', 'mixer_l*',
@ -373,10 +378,10 @@ if 'GITHUB_ACTIONS' in os.environ:
] ]
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS)) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size): def test_model_forward_fx(model_name, batch_size):
""" """
Symbolically trace each model and run single forward pass through the resulting GraphModule Symbolically trace each model and run single forward pass through the resulting GraphModule
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
@ -406,11 +411,11 @@ def test_model_forward_fx(model_name, batch_size):
assert not torch.isnan(outputs).any(), 'Output included NaNs' assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models( @pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True)) exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2]) @pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size): def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule""" """Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction: if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
@ -439,7 +444,7 @@ def test_model_backward_fx(model_name, batch_size):
assert not torch.isnan(outputs).any(), 'Output included NaNs' assert not torch.isnan(outputs).any(), 'Output included NaNs'
if 'GITHUB_ACTIONS' not in os.environ: if 'GITHUB_ACTIONS' not in os.environ:
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process # FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow

Loading…
Cancel
Save