From 1076a65df1fa8fe0612adbe21284f4722b585dac Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 18 Nov 2021 19:47:07 -0800 Subject: [PATCH] Minor post FX merge cleanup --- tests/test_models.py | 6 +++--- timm/models/fx_features.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 93152d9a..7a3f143e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -315,7 +315,7 @@ def test_model_forward_fx(model_name, batch_size): Also check that the output of a forward pass through the GraphModule is the same as that from the original Module """ if not has_fx_feature_extraction: - pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") model = create_model(model_name, pretrained=False) model.eval() @@ -360,7 +360,7 @@ def test_model_forward_fx(model_name, batch_size): def test_model_backward_fx(model_name, batch_size): """Symbolically trace each model and run single backward pass through the resulting GraphModule""" if not has_fx_feature_extraction: - pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) if max(input_size) > MAX_BWD_SIZE: @@ -421,7 +421,7 @@ EXCLUDE_FX_JIT_FILTERS = [ def test_model_forward_fx_torchscript(model_name, batch_size): """Symbolically trace each model, script it, and run single forward pass""" if not has_fx_feature_extraction: - pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required") + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) if max(input_size) > MAX_JIT_SIZE: diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index 2e01586b..5a25ee3e 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -71,4 +71,3 @@ class FeatureGraphNet(nn.Module): def forward(self, x): return list(self.graph_module(x).values()) - \ No newline at end of file