|
|
|
@ -4,7 +4,11 @@ import platform
|
|
|
|
|
import os
|
|
|
|
|
import fnmatch
|
|
|
|
|
|
|
|
|
|
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
|
|
|
|
|
try:
|
|
|
|
|
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
|
|
|
|
|
has_fx_feature_extraction = True
|
|
|
|
|
except ImportError:
|
|
|
|
|
has_fx_feature_extraction = False
|
|
|
|
|
|
|
|
|
|
import timm
|
|
|
|
|
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
|
|
|
|
@ -307,6 +311,9 @@ def test_model_forward_features(model_name, batch_size):
|
|
|
|
|
@pytest.mark.parametrize('batch_size', [1])
|
|
|
|
|
def test_model_forward_fx(model_name, batch_size):
|
|
|
|
|
"""Symbolically trace each model and run single forward pass through the resulting GraphModule"""
|
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
|
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
|
|
|
|
|
|
|
|
|
|
model = create_model(model_name, pretrained=False)
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
@ -332,6 +339,9 @@ def test_model_forward_fx(model_name, batch_size):
|
|
|
|
|
@pytest.mark.parametrize('batch_size', [2])
|
|
|
|
|
def test_model_backward_fx(model_name, batch_size):
|
|
|
|
|
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
|
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
|
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
|
|
|
|
|
|
|
|
|
|
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
|
|
|
|
|
if max(input_size) > MAX_BWD_SIZE:
|
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
@ -387,6 +397,9 @@ EXCLUDE_FX_JIT_FILTERS = [
|
|
|
|
|
@pytest.mark.parametrize('batch_size', [1])
|
|
|
|
|
def test_model_forward_fx_torchscript(model_name, batch_size):
|
|
|
|
|
"""Symbolically trace each model, script it, and run single forward pass"""
|
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
|
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
|
|
|
|
|
|
|
|
|
|
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
|
|
|
|
|
if max(input_size) > MAX_JIT_SIZE:
|
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
|