|
|
|
@ -310,7 +310,10 @@ def test_model_forward_features(model_name, batch_size):
|
|
|
|
|
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
|
|
|
|
|
@pytest.mark.parametrize('batch_size', [1])
|
|
|
|
|
def test_model_forward_fx(model_name, batch_size):
|
|
|
|
|
"""Symbolically trace each model and run single forward pass through the resulting GraphModule"""
|
|
|
|
|
"""
|
|
|
|
|
Symbolically trace each model and run single forward pass through the resulting GraphModule
|
|
|
|
|
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
|
|
|
|
|
"""
|
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
|
pytest.skip("Can't test FX because Torch >= 1.10 and Torchvision >= 0.11 are required")
|
|
|
|
|
|
|
|
|
@ -321,15 +324,32 @@ def test_model_forward_fx(model_name, batch_size):
|
|
|
|
|
if max(input_size) > MAX_FWD_SIZE:
|
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
|
|
|
|
|
|
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
|
|
|
|
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
|
|
|
|
|
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
|
|
|
|
|
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
|
|
|
|
|
graph = tracer.trace(model)
|
|
|
|
|
graph_nodes = list(reversed(graph.nodes))
|
|
|
|
|
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
|
|
|
|
|
graph_node_names = [n.name for n in graph_nodes]
|
|
|
|
|
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
|
|
|
|
|
train_nodes, eval_nodes = get_graph_node_names(
|
|
|
|
|
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
|
|
|
|
|
model = create_feature_extractor(
|
|
|
|
|
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
|
|
|
|
|
eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]
|
|
|
|
|
|
|
|
|
|
fx_model = create_feature_extractor(
|
|
|
|
|
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes,
|
|
|
|
|
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
|
|
|
|
|
|
|
|
|
|
inputs = torch.randn((batch_size, *input_size))
|
|
|
|
|
outputs = model(inputs)[eval_nodes[-1]]
|
|
|
|
|
outputs = model(inputs)
|
|
|
|
|
if isinstance(outputs, tuple):
|
|
|
|
|
outputs = torch.cat(outputs)
|
|
|
|
|
fx_outputs = tuple(fx_model(inputs).values())
|
|
|
|
|
if isinstance(fx_outputs, tuple):
|
|
|
|
|
fx_outputs = torch.cat(fx_outputs)
|
|
|
|
|
|
|
|
|
|
assert torch.all(fx_outputs == outputs)
|
|
|
|
|
assert outputs.shape[0] == batch_size
|
|
|
|
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
|
|
|
|
|
|
|
@ -348,6 +368,7 @@ def test_model_backward_fx(model_name, batch_size):
|
|
|
|
|
|
|
|
|
|
model = create_model(model_name, pretrained=False, num_classes=42)
|
|
|
|
|
model.train()
|
|
|
|
|
|
|
|
|
|
num_params = sum([x.numel() for x in model.parameters()])
|
|
|
|
|
|
|
|
|
|
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
|
|
|
|
@ -355,7 +376,6 @@ def test_model_backward_fx(model_name, batch_size):
|
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
|
|
|
|
|
|
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
|
|
|
|
# If so, we need to return all of them in order to check all grads
|
|
|
|
|
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
|
|
|
|
|
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
|
|
|
|
|
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
|
|
|
|
@ -385,9 +405,12 @@ def test_model_backward_fx(model_name, batch_size):
|
|
|
|
|
assert num_params == num_grad, 'Some parameters are missing gradients'
|
|
|
|
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
|
|
|
|
|
EXCLUDE_FX_JIT_FILTERS = [
|
|
|
|
|
'beit_*' # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
|
|
|
|
|
'beit_*',
|
|
|
|
|
'deit_*_distilled_patch16_224',
|
|
|
|
|
'levit*',
|
|
|
|
|
'pit_*_distilled_224',
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
@pytest.mark.timeout(120)
|
|
|
|
|