|
|
@ -315,7 +315,7 @@ def test_model_forward_fx(model_name, batch_size):
|
|
|
|
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
|
|
|
|
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
|
|
|
|
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
|
|
|
|
|
|
|
|
|
|
|
model = create_model(model_name, pretrained=False)
|
|
|
|
model = create_model(model_name, pretrained=False)
|
|
|
|
model.eval()
|
|
|
|
model.eval()
|
|
|
@ -360,7 +360,7 @@ def test_model_forward_fx(model_name, batch_size):
|
|
|
|
def test_model_backward_fx(model_name, batch_size):
|
|
|
|
def test_model_backward_fx(model_name, batch_size):
|
|
|
|
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
|
|
|
|
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
|
|
|
|
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
|
|
|
|
|
|
|
|
|
|
|
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
|
|
|
|
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
|
|
|
|
if max(input_size) > MAX_BWD_SIZE:
|
|
|
|
if max(input_size) > MAX_BWD_SIZE:
|
|
|
@ -421,7 +421,7 @@ EXCLUDE_FX_JIT_FILTERS = [
|
|
|
|
def test_model_forward_fx_torchscript(model_name, batch_size):
|
|
|
|
def test_model_forward_fx_torchscript(model_name, batch_size):
|
|
|
|
"""Symbolically trace each model, script it, and run single forward pass"""
|
|
|
|
"""Symbolically trace each model, script it, and run single forward pass"""
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
|
|
|
|
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
|
|
|
|
|
|
|
|
|
|
|
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
|
|
|
|
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
|
|
|
|
if max(input_size) > MAX_JIT_SIZE:
|
|
|
|
if max(input_size) > MAX_JIT_SIZE:
|
|
|
|