fx ready for review

pull/800/head
Alexander Soare 3 years ago
parent d2994016e9
commit 0262a0e8e1

tests/test_models.py
@@ -310,7 +310,10 @@ def test_model_forward_features(model_name, batch_size):
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
"""Symbolically trace each model and run single forward pass through the resulting GraphModule"""
"""
Symbolically trace each model and run single forward pass through the resulting GraphModule
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
@@ -321,15 +324,32 @@ def test_model_forward_fx(model_name, batch_size):
if max(input_size) > MAX_FWD_SIZE:
pytest.skip("Fixed input size model > limit.")
# This block of code does a bit of juggling to handle any case where there are multiple outputs in eval mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from eval_nodes returned by torchvision get_graph_node_names
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
model = create_feature_extractor(
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]
fx_model = create_feature_extractor(
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)[eval_nodes[-1]]
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
fx_outputs = tuple(fx_model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)
assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
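Aside (not part of the diff): a minimal, self-contained sketch of the output-node juggling above, assuming torchvision >= 0.11. ToyTwoHead is an invented stand-in for a model that returns multiple outputs; the negative-index trick lines up because get_graph_node_names lists every traced node except the graph's output node, so the same offsets from the end select the matching qualified names.

import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

class ToyTwoHead(nn.Module):
    """Invented toy model whose forward returns a tuple of two outputs."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.head = nn.Linear(8, 4)
        self.head_dist = nn.Linear(8, 4)

    def forward(self, x):
        x = self.backbone(x)
        return self.head(x), self.head_dist(x)

model = ToyTwoHead().eval()

# Trace once and find the nodes that feed the graph's output node
graph = symbolic_trace(model).graph
graph_nodes = list(reversed(graph.nodes))
graph_node_names = [n.name for n in graph_nodes]
output_node_names = [n.name for n in graph_nodes[0].all_input_nodes]
output_node_indices = [-graph_node_names.index(name) for name in output_node_names]

# The same negative indices pick the matching qualified names from get_graph_node_names
train_nodes, eval_nodes = get_graph_node_names(model)
eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]

fx_model = create_feature_extractor(model, return_nodes=eval_return_nodes)

inputs = torch.randn(2, 8)
outputs = torch.cat(model(inputs))
fx_outputs = torch.cat(tuple(fx_model(inputs).values()))
assert torch.all(fx_outputs == outputs)  # the GraphModule reproduces the original Module's outputs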
@@ -348,6 +368,7 @@ def test_model_backward_fx(model_name, batch_size):
model = create_model(model_name, pretrained=False, num_classes=42)
model.train()
num_params = sum([x.numel() for x in model.parameters()])
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
@@ -355,7 +376,6 @@ def test_model_backward_fx(model_name, batch_size):
pytest.skip("Fixed input size model > limit.")
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# If so, we need to return all of them in order to check all grads
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
@@ -385,9 +405,12 @@ def test_model_backward_fx(model_name, batch_size):
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
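# Aside (not part of the diff): a small sketch of why all train-mode outputs must be
# returned before the gradient check above. The ModuleDict model here is an invented
# stand-in for a model (e.g. a distilled ViT) with two heads in train mode: if only one
# head's output were kept, the other head's parameters would never receive gradients
# and the num_params == num_grad assertion would (correctly) fail.
import torch
import torch.nn as nn

model = nn.ModuleDict({
    'body': nn.Linear(8, 8),
    'head': nn.Linear(8, 4),
    'head_dist': nn.Linear(8, 4),
}).train()
num_params = sum(p.numel() for p in model.parameters())

x = model['body'](torch.randn(2, 8))
outputs = torch.cat([model['head'](x), model['head_dist'](x)])  # keep both outputs
outputs.mean().backward()
num_grad = sum(p.grad.numel() for p in model.parameters() if p.grad is not None)
assert num_params == num_grad, 'Some parameters are missing gradients'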
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'beit_*' # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
'beit_*',
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
]
@pytest.mark.timeout(120)
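Aside (not part of the diff): the EXCLUDE_FX_JIT_FILTERS entries above guard parts of their forward passes with torch.jit.is_scripting() (the reason the comment gives for beit). A small illustration of why that clashes with scripting after FX tracing: the guard is an ordinary Python call that returns False at trace time, so only the non-scripting branch survives into the GraphModule, and the branch the guard was meant to hand to the TorchScript compiler is gone, which is why scripting the traced model can then fail. Guarded is an invented toy module, not the actual beit/levit code.

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class Guarded(nn.Module):
    def forward(self, x):
        if torch.jit.is_scripting():
            y = x * 2        # branch written for the TorchScript compiler
        else:
            y = x * 2 + 0    # stand-in for an eager-only code path
        return y

print(symbolic_trace(Guarded()).code)  # the printed forward contains only the else-branch ops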

timm/models/nfnet.py
@@ -26,6 +26,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_leaf_module
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
@@ -318,6 +319,7 @@ class DownsampleAvg(nn.Module):
return self.conv(self.pool(x))
@register_leaf_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
class NormFreeBlock(nn.Module):
"""Normalization-Free pre-activation block.
"""
