diff --git a/tests/test_models.py b/tests/test_models.py index f7233ef3..e513dcaf 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,7 +4,11 @@ import platform import os import fnmatch -from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer +try: + from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer + has_fx_feature_extraction = True +except ImportError: + has_fx_feature_extraction = False import timm from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \ @@ -307,6 +311,9 @@ def test_model_forward_features(model_name, batch_size): @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_fx(model_name, batch_size): """Symbolically trace each model and run single forward pass through the resulting GraphModule""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + model = create_model(model_name, pretrained=False) model.eval() @@ -332,6 +339,9 @@ def test_model_forward_fx(model_name, batch_size): @pytest.mark.parametrize('batch_size', [2]) def test_model_backward_fx(model_name, batch_size): """Symbolically trace each model and run single backward pass through the resulting GraphModule""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) if max(input_size) > MAX_BWD_SIZE: pytest.skip("Fixed input size model > limit.") @@ -387,6 +397,9 @@ EXCLUDE_FX_JIT_FILTERS = [ @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_fx_torchscript(model_name, batch_size): """Symbolically trace each model, script it, and run single forward pass""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) if max(input_size) > MAX_JIT_SIZE: pytest.skip("Fixed input size model > limit.") diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index c8f296d4..a582cf9b 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -8,8 +8,9 @@ from .features import _get_feature_info try: from torchvision.models.feature_extraction import create_feature_extractor + has_fx_feature_extraction = True except ImportError: - pass + has_fx_feature_extraction = False # Layers we went to treat as leaf modules from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath @@ -58,6 +59,7 @@ def register_autowrap_function(func: Callable): class FeatureGraphNet(nn.Module): def __init__(self, model, out_indices, out_map=None): super().__init__() + assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' self.feature_info = _get_feature_info(model, out_indices) if out_map is not None: assert len(out_map) == len(out_indices) @@ -66,7 +68,7 @@ class FeatureGraphNet(nn.Module): self.graph_module = create_feature_extractor( model, return_nodes, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}) - + def forward(self, x): return list(self.graph_module(x).values()) \ No newline at end of file