diff --git a/tests/test_models.py b/tests/test_models.py index 7ea9af6e..94744483 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,6 +4,8 @@ import platform import os import fnmatch +_IS_MAC = platform.system() == 'Darwin' + try: from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer has_fx_feature_extraction = True @@ -322,157 +324,160 @@ def test_model_forward_features(model_name, batch_size): assert not torch.isnan(o).any() -def _create_fx_model(model, train=False): - # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode - # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output - # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names - tracer_kwargs = dict( - leaf_modules=list(_leaf_modules), - autowrap_functions=list(_autowrap_functions), - #enable_cpatching=True, - param_shapes_constant=True - ) - train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs) - - eval_return_nodes = [eval_nodes[-1]] - train_return_nodes = [train_nodes[-1]] - if train: - tracer = NodePathTracer(**tracer_kwargs) - graph = tracer.trace(model) - graph_nodes = list(reversed(graph.nodes)) - output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] - graph_node_names = [n.name for n in graph_nodes] - output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names] - train_return_nodes = [train_nodes[ix] for ix in output_node_indices] - - fx_model = create_feature_extractor( - model, - train_return_nodes=train_return_nodes, - eval_return_nodes=eval_return_nodes, - tracer_kwargs=tracer_kwargs, - ) - return fx_model - - -EXCLUDE_FX_FILTERS = ['vit_gi*'] -# not enough memory to run fx on more models than other tests -if 'GITHUB_ACTIONS' in os.environ: - EXCLUDE_FX_FILTERS += [ - 'beit_large*', - 'mixer_l*', - '*nfnet_f2*', - '*resnext101_32x32d', - 'resnetv2_152x2*', - 'resmlp_big*', - 'resnetrs270', - 'swin_large*', - 'vgg*', - 'vit_large*', - 'vit_base_patch8*', - 'xcit_large*', - ] +if not _IS_MAC: + # MACOS test runners are really slow, only running tests below this point if not on a Darwin runner... + + def _create_fx_model(model, train=False): + # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode + # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output + # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names + tracer_kwargs = dict( + leaf_modules=list(_leaf_modules), + autowrap_functions=list(_autowrap_functions), + #enable_cpatching=True, + param_shapes_constant=True + ) + train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs) + + eval_return_nodes = [eval_nodes[-1]] + train_return_nodes = [train_nodes[-1]] + if train: + tracer = NodePathTracer(**tracer_kwargs) + graph = tracer.trace(model) + graph_nodes = list(reversed(graph.nodes)) + output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] + graph_node_names = [n.name for n in graph_nodes] + output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names] + train_return_nodes = [train_nodes[ix] for ix in output_node_indices] + + fx_model = create_feature_extractor( + model, + train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes, + tracer_kwargs=tracer_kwargs, + ) + return fx_model + + + EXCLUDE_FX_FILTERS = ['vit_gi*'] + # not enough memory to run fx on more models than other tests + if 'GITHUB_ACTIONS' in os.environ: + EXCLUDE_FX_FILTERS += [ + 'beit_large*', + 'mixer_l*', + '*nfnet_f2*', + '*resnext101_32x32d', + 'resnetv2_152x2*', + 'resmlp_big*', + 'resnetrs270', + 'swin_large*', + 'vgg*', + 'vit_large*', + 'vit_base_patch8*', + 'xcit_large*', + ] -@pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS)) -@pytest.mark.parametrize('batch_size', [1]) -def test_model_forward_fx(model_name, batch_size): - """ - Symbolically trace each model and run single forward pass through the resulting GraphModule - Also check that the output of a forward pass through the GraphModule is the same as that from the original Module - """ - if not has_fx_feature_extraction: - pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") + @pytest.mark.timeout(120) + @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS)) + @pytest.mark.parametrize('batch_size', [1]) + def test_model_forward_fx(model_name, batch_size): + """ + Symbolically trace each model and run single forward pass through the resulting GraphModule + Also check that the output of a forward pass through the GraphModule is the same as that from the original Module + """ + if not has_fx_feature_extraction: + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") - model = create_model(model_name, pretrained=False) - model.eval() + model = create_model(model_name, pretrained=False) + model.eval() - input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE) - if max(input_size) > MAX_FWD_FX_SIZE: - pytest.skip("Fixed input size model > limit.") - with torch.no_grad(): - inputs = torch.randn((batch_size, *input_size)) - outputs = model(inputs) - if isinstance(outputs, tuple): - outputs = torch.cat(outputs) + input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE) + if max(input_size) > MAX_FWD_FX_SIZE: + pytest.skip("Fixed input size model > limit.") + with torch.no_grad(): + inputs = torch.randn((batch_size, *input_size)) + outputs = model(inputs) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) - model = _create_fx_model(model) - fx_outputs = tuple(model(inputs).values()) - if isinstance(fx_outputs, tuple): - fx_outputs = torch.cat(fx_outputs) + model = _create_fx_model(model) + fx_outputs = tuple(model(inputs).values()) + if isinstance(fx_outputs, tuple): + fx_outputs = torch.cat(fx_outputs) - assert torch.all(fx_outputs == outputs) - assert outputs.shape[0] == batch_size - assert not torch.isnan(outputs).any(), 'Output included NaNs' + assert torch.all(fx_outputs == outputs) + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs' -@pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models( - exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True)) -@pytest.mark.parametrize('batch_size', [2]) -def test_model_backward_fx(model_name, batch_size): - """Symbolically trace each model and run single backward pass through the resulting GraphModule""" - if not has_fx_feature_extraction: - pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") + @pytest.mark.timeout(120) + @pytest.mark.parametrize('model_name', list_models( + exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True)) + @pytest.mark.parametrize('batch_size', [2]) + def test_model_backward_fx(model_name, batch_size): + """Symbolically trace each model and run single backward pass through the resulting GraphModule""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") - input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE) - if max(input_size) > MAX_BWD_FX_SIZE: - pytest.skip("Fixed input size model > limit.") + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE) + if max(input_size) > MAX_BWD_FX_SIZE: + pytest.skip("Fixed input size model > limit.") - model = create_model(model_name, pretrained=False, num_classes=42) - model.train() - num_params = sum([x.numel() for x in model.parameters()]) - if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6: - pytest.skip("Skipping FX backward test on model with more than 100M params.") + model = create_model(model_name, pretrained=False, num_classes=42) + model.train() + num_params = sum([x.numel() for x in model.parameters()]) + if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6: + pytest.skip("Skipping FX backward test on model with more than 100M params.") - model = _create_fx_model(model, train=True) - outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) - if isinstance(outputs, tuple): - outputs = torch.cat(outputs) - outputs.mean().backward() - for n, x in model.named_parameters(): - assert x.grad is not None, f'No gradient for {n}' - num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None]) + model = _create_fx_model(model, train=True) + outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) + outputs.mean().backward() + for n, x in model.named_parameters(): + assert x.grad is not None, f'No gradient for {n}' + num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None]) - assert outputs.shape[-1] == 42 - assert num_params == num_grad, 'Some parameters are missing gradients' - assert not torch.isnan(outputs).any(), 'Output included NaNs' + assert outputs.shape[-1] == 42 + assert num_params == num_grad, 'Some parameters are missing gradients' + assert not torch.isnan(outputs).any(), 'Output included NaNs' -if 'GITHUB_ACTIONS' not in os.environ: - # FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process + if 'GITHUB_ACTIONS' not in os.environ: + # FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process - # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow - EXCLUDE_FX_JIT_FILTERS = [ - 'deit_*_distilled_patch16_224', - 'levit*', - 'pit_*_distilled_224', - ] + EXCLUDE_FX_FILTERS + # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow + EXCLUDE_FX_JIT_FILTERS = [ + 'deit_*_distilled_patch16_224', + 'levit*', + 'pit_*_distilled_224', + ] + EXCLUDE_FX_FILTERS - @pytest.mark.timeout(120) - @pytest.mark.parametrize( - 'model_name', list_models( - exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True)) - @pytest.mark.parametrize('batch_size', [1]) - def test_model_forward_fx_torchscript(model_name, batch_size): - """Symbolically trace each model, script it, and run single forward pass""" - if not has_fx_feature_extraction: - pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") + @pytest.mark.timeout(120) + @pytest.mark.parametrize( + 'model_name', list_models( + exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True)) + @pytest.mark.parametrize('batch_size', [1]) + def test_model_forward_fx_torchscript(model_name, batch_size): + """Symbolically trace each model, script it, and run single forward pass""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") - input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) - if max(input_size) > MAX_JIT_SIZE: - pytest.skip("Fixed input size model > limit.") + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) + if max(input_size) > MAX_JIT_SIZE: + pytest.skip("Fixed input size model > limit.") - with set_scriptable(True): - model = create_model(model_name, pretrained=False) - model.eval() + with set_scriptable(True): + model = create_model(model_name, pretrained=False) + model.eval() - model = torch.jit.script(_create_fx_model(model)) - with torch.no_grad(): - outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) - if isinstance(outputs, tuple): - outputs = torch.cat(outputs) + model = torch.jit.script(_create_fx_model(model)) + with torch.no_grad(): + outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) - assert outputs.shape[0] == batch_size - assert not torch.isnan(outputs).any(), 'Output included NaNs' + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs'