|
|
@ -386,37 +386,41 @@ def test_model_forward_fx(model_name, batch_size):
|
|
|
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
|
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.timeout(120)
|
|
|
|
if 'GITHUB_ACTIONS' not in os.environ:
|
|
|
|
@pytest.mark.parametrize('model_name', list_models(
|
|
|
|
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
|
|
|
|
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
|
|
|
|
|
|
|
|
@pytest.mark.parametrize('batch_size', [2])
|
|
|
|
|
|
|
|
def test_model_backward_fx(model_name, batch_size):
|
|
|
|
|
|
|
|
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
|
|
|
|
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
|
|
|
|
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
|
|
|
|
|
|
|
|
if max(input_size) > MAX_BWD_FX_SIZE:
|
|
|
|
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = create_model(model_name, pretrained=False, num_classes=42)
|
|
|
|
@pytest.mark.timeout(120)
|
|
|
|
model.train()
|
|
|
|
@pytest.mark.parametrize('model_name', list_models(
|
|
|
|
num_params = sum([x.numel() for x in model.parameters()])
|
|
|
|
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
|
|
|
|
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
|
|
|
|
@pytest.mark.parametrize('batch_size', [2])
|
|
|
|
pytest.skip("Skipping FX backward test on model with more than 100M params.")
|
|
|
|
def test_model_backward_fx(model_name, batch_size):
|
|
|
|
|
|
|
|
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
|
|
|
|
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
|
|
|
|
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
|
|
|
|
|
|
|
|
if max(input_size) > MAX_BWD_FX_SIZE:
|
|
|
|
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = create_model(model_name, pretrained=False, num_classes=42)
|
|
|
|
|
|
|
|
model.train()
|
|
|
|
|
|
|
|
num_params = sum([x.numel() for x in model.parameters()])
|
|
|
|
|
|
|
|
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
|
|
|
|
|
|
|
|
pytest.skip("Skipping FX backward test on model with more than 100M params.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = _create_fx_model(model, train=True)
|
|
|
|
|
|
|
|
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
|
|
|
|
|
|
|
if isinstance(outputs, tuple):
|
|
|
|
|
|
|
|
outputs = torch.cat(outputs)
|
|
|
|
|
|
|
|
outputs.mean().backward()
|
|
|
|
|
|
|
|
for n, x in model.named_parameters():
|
|
|
|
|
|
|
|
assert x.grad is not None, f'No gradient for {n}'
|
|
|
|
|
|
|
|
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
|
|
|
|
|
|
|
|
|
|
|
|
model = _create_fx_model(model, train=True)
|
|
|
|
assert outputs.shape[-1] == 42
|
|
|
|
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
|
|
|
assert num_params == num_grad, 'Some parameters are missing gradients'
|
|
|
|
if isinstance(outputs, tuple):
|
|
|
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
|
|
outputs = torch.cat(outputs)
|
|
|
|
|
|
|
|
outputs.mean().backward()
|
|
|
|
|
|
|
|
for n, x in model.named_parameters():
|
|
|
|
|
|
|
|
assert x.grad is not None, f'No gradient for {n}'
|
|
|
|
|
|
|
|
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert outputs.shape[-1] == 42
|
|
|
|
|
|
|
|
assert num_params == num_grad, 'Some parameters are missing gradients'
|
|
|
|
|
|
|
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
|
|
|
|
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
|
|
|
|
EXCLUDE_FX_JIT_FILTERS = [
|
|
|
|
EXCLUDE_FX_JIT_FILTERS = [
|
|
|
|