Minor post FX merge cleanup

pull/989/head
Ross Wightman 3 years ago
parent 32c9937dec
commit 1076a65df1

@ -315,7 +315,7 @@ def test_model_forward_fx(model_name, batch_size):
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
model = create_model(model_name, pretrained=False)
model.eval()
@ -360,7 +360,7 @@ def test_model_forward_fx(model_name, batch_size):
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
if max(input_size) > MAX_BWD_SIZE:
@ -421,7 +421,7 @@ EXCLUDE_FX_JIT_FILTERS = [
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:

@ -71,4 +71,3 @@ class FeatureGraphNet(nn.Module):
def forward(self, x):
return list(self.graph_module(x).values())
Loading…
Cancel
Save