Add try/except guards

pull/800/head
Alexander Soare 3 years ago
parent b25ff96768
commit d2994016e9

@ -4,7 +4,11 @@ import platform
import os
import fnmatch
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
try:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
has_fx_feature_extraction = True
except ImportError:
has_fx_feature_extraction = False
import timm
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
@ -307,6 +311,9 @@ def test_model_forward_features(model_name, batch_size):
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
"""Symbolically trace each model and run single forward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
model = create_model(model_name, pretrained=False)
model.eval()
@ -332,6 +339,9 @@ def test_model_forward_fx(model_name, batch_size):
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
if max(input_size) > MAX_BWD_SIZE:
pytest.skip("Fixed input size model > limit.")
@ -387,6 +397,9 @@ EXCLUDE_FX_JIT_FILTERS = [
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")

@ -8,8 +8,9 @@ from .features import _get_feature_info
try:
from torchvision.models.feature_extraction import create_feature_extractor
has_fx_feature_extraction = True
except ImportError:
pass
has_fx_feature_extraction = False
# Layers we went to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
@ -58,6 +59,7 @@ def register_autowrap_function(func: Callable):
class FeatureGraphNet(nn.Module):
def __init__(self, model, out_indices, out_map=None):
super().__init__()
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
self.feature_info = _get_feature_info(model, out_indices)
if out_map is not None:
assert len(out_map) == len(out_indices)

Loading…
Cancel
Save