diff --git a/tests/test_models.py b/tests/test_models.py
index e513dcaf..93152d9a 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -310,7 +310,10 @@ def test_model_forward_features(model_name, batch_size):
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_fx(model_name, batch_size):
-    """Symbolically trace each model and run single forward pass through the resulting GraphModule"""
+    """
+    Symbolically trace each model and run single forward pass through the resulting GraphModule
+    Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
+    """
     if not has_fx_feature_extraction:
         pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")

@@ -321,15 +324,32 @@ def test_model_forward_fx(model_name, batch_size):
     if max(input_size) > MAX_FWD_SIZE:
         pytest.skip("Fixed input size model > limit.")

+    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
+    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
+    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
+    tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
+    graph = tracer.trace(model)
+    graph_nodes = list(reversed(graph.nodes))
+    output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
+    graph_node_names = [n.name for n in graph_nodes]
+    output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
     train_nodes, eval_nodes = get_graph_node_names(
         model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
-    model = create_feature_extractor(
-        model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
+    eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]
+
+    fx_model = create_feature_extractor(
+        model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes,
         tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})

     inputs = torch.randn((batch_size, *input_size))
-    outputs = model(inputs)[eval_nodes[-1]]
+    outputs = model(inputs)
+    if isinstance(outputs, tuple):
+        outputs = torch.cat(outputs)
+    fx_outputs = tuple(fx_model(inputs).values())
+    if isinstance(fx_outputs, tuple):
+        fx_outputs = torch.cat(fx_outputs)
+    assert torch.all(fx_outputs == outputs)
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'


@@ -348,6 +368,7 @@ def test_model_backward_fx(model_name, batch_size):

     model = create_model(model_name, pretrained=False, num_classes=42)
     model.train()
+    num_params = sum([x.numel() for x in model.parameters()])

     input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
@@ -355,7 +376,6 @@ def test_model_backward_fx(model_name, batch_size):
     if max(input_size) > MAX_FWD_SIZE:
         pytest.skip("Fixed input size model > limit.")
     # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
-    # If so, we need to return all of them in order to check all grads
     # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
     # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
     tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
@@ -385,9 +405,12 @@ def test_model_backward_fx(model_name, batch_size):
     assert num_params == num_grad, 'Some parameters are missing gradients'
     assert not torch.isnan(outputs).any(), 'Output included NaNs'

-
+# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
 EXCLUDE_FX_JIT_FILTERS = [
-    'beit_*'  # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
+    'beit_*',
+    'deit_*_distilled_patch16_224',
+    'levit*',
+    'pit_*_distilled_224',
 ]

 @pytest.mark.timeout(120)
diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py
index 4e0f2b21..1d6cbb38 100644
--- a/timm/models/nfnet.py
+++ b/timm/models/nfnet.py
@@ -26,6 +26,7 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_leaf_module
 from .helpers import build_model_with_cfg
 from .registry import register_model
 from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
@@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module):
         return self.conv(self.pool(x))


+@register_leaf_module  # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
 class NormFreeBlock(nn.Module):
     """Normalization-Free pre-activation block.
     """
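Note on the approach (not part of the patch): the test changes above rely on the fact that an FX graph's final node is its 'output' node, and that node's inputs are exactly the values the original forward() returns. The sketch below illustrates the same idea with plain torch.fx and a toy two-headed module; ToyModel and LeafAwareTracer are illustrative assumptions, not timm's NodePathTracer or the exact test code, and it uses the public Node.all_input_nodes rather than the private _input_nodes seen in the patch.

    import torch
    import torch.fx


    class ToyModel(torch.nn.Module):
        """Toy stand-in for a model that returns multiple tensors in train mode."""
        def __init__(self):
            super().__init__()
            self.stem = torch.nn.Linear(8, 8)
            self.head_a = torch.nn.Linear(8, 4)
            self.head_b = torch.nn.Linear(8, 4)

        def forward(self, x):
            x = self.stem(x)
            return self.head_a(x), self.head_b(x)


    model = ToyModel()
    graph = torch.fx.symbolic_trace(model).graph

    # The last node of an FX graph is the 'output' node; its input nodes are the tensors
    # the original Module returns, so their names identify which return nodes a feature
    # extractor must expose for the GraphModule to reproduce every original output.
    output_node = list(graph.nodes)[-1]
    assert output_node.op == 'output'
    return_node_names = [n.name for n in output_node.all_input_nodes]
    print(return_node_names)  # e.g. ['head_a', 'head_b']


    # Roughly what registering a leaf module (as done for NormFreeBlock in nfnet.py)
    # amounts to in plain torch.fx terms: a Tracer that refuses to trace into certain
    # module types, so they stay as single call_module nodes and ops such as the
    # in-place mul_ inside the block are never inlined into the graph.
    class LeafAwareTracer(torch.fx.Tracer):
        def __init__(self, leaf_types=()):
            super().__init__()
            self.leaf_types = tuple(leaf_types)

        def is_leaf_module(self, m, module_qualified_name):
            return isinstance(m, self.leaf_types) or super().is_leaf_module(m, module_qualified_name)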