diff --git a/tests/test_models.py b/tests/test_models.py index c0d0e901..91fd3543 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -7,6 +7,7 @@ import fnmatch import timm from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \ get_model_default_value +from timm.models.fx_features import NodePathTracer if hasattr(torch._C, '_jit_set_profiling_executor'): # legacy executor is too slow to compile large models for unit tests @@ -297,3 +298,82 @@ def test_model_forward_features(model_name, batch_size): assert e == o.shape[1] assert o.shape[0] == batch_size assert not torch.isnan(o).any() + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward_fx(model_name, batch_size): + """Symbolically trace each model and run single forward pass through the resulting GraphModule""" + model = create_model(model_name, pretrained=False) + model.eval() + + input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") + + tracer = NodePathTracer() + graph = tracer.trace(model) + model = torch.fx.GraphModule(model, graph) + + inputs = torch.randn((batch_size, *input_size)) + outputs = model(inputs) + + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs' + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) +@pytest.mark.parametrize('batch_size', [2]) +def test_model_backward_fx(model_name, batch_size): + """Symbolically trace each model and run single backward pass through the resulting GraphModule""" + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) + if max(input_size) > MAX_BWD_SIZE: + pytest.skip("Fixed input size model > limit.") + + model = create_model(model_name, pretrained=False, num_classes=42) + model.train() + num_params = sum([x.numel() for x in model.parameters()]) + + tracer = NodePathTracer() + graph = tracer.trace(model) + model = torch.fx.GraphModule(model, graph) + + inputs = torch.randn((batch_size, *input_size)) + outputs = model(inputs) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) + outputs.mean().backward() + for n, x in model.named_parameters(): + assert x.grad is not None, f'No gradient for {n}' + num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None]) + + assert outputs.shape[-1] == 42 + assert num_params == num_grad, 'Some parameters are missing gradients' + assert not torch.isnan(outputs).any(), 'Output included NaNs' + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize( + 'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True)) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward_fx_torchscript(model_name, batch_size): + """Symbolically trace each model, script it, and run single forward pass""" + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) + if max(input_size) > MAX_JIT_SIZE: + pytest.skip("Fixed input size model > limit.") + + with set_scriptable(True): + model = create_model(model_name, pretrained=False) + model.eval() + + tracer = NodePathTracer() + graph = tracer.trace(model) + model = torch.fx.GraphModule(model, graph) + + model = torch.jit.script(model) + outputs = model(torch.randn((batch_size, *input_size))) + + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs' \ No newline at end of file