|
|
|
@ -399,8 +399,10 @@ def test_model_backward_fx(model_name, batch_size):
|
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
|
|
|
|
|
|
model = create_model(model_name, pretrained=False, num_classes=42)
|
|
|
|
|
num_params = sum([x.numel() for x in model.parameters()])
|
|
|
|
|
model.train()
|
|
|
|
|
num_params = sum([x.numel() for x in model.parameters()])
|
|
|
|
|
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
|
|
|
|
|
pytest.skip("Skipping FX backward test on model with more than 100M params.")
|
|
|
|
|
|
|
|
|
|
model = _create_fx_model(model, train=True)
|
|
|
|
|
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
|
|
|
|