diff --git a/tests/test_models.py b/tests/test_models.py
index f4520720..1750d540 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -33,7 +33,7 @@ if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
     EXCLUDE_FILTERS = [
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
-        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'beit_large*']
+        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*']
 else:
     EXCLUDE_FILTERS = []
 
@@ -45,6 +45,10 @@ TARGET_JIT_SIZE = 128
 MAX_JIT_SIZE = 320
 TARGET_FFEAT_SIZE = 96
 MAX_FFEAT_SIZE = 256
+TARGET_FWD_FX_SIZE = 128
+MAX_FWD_FX_SIZE = 224
+TARGET_BWD_FX_SIZE = 128
+MAX_BWD_FX_SIZE = 224
 
 
 def _get_input_size(model=None, model_name='', target=None):
@@ -306,6 +310,30 @@ def test_model_forward_features(model_name, batch_size):
         assert not torch.isnan(o).any()
 
 
+def _create_fx_model(model, train=False):
+    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
+    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
+    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
+    train_nodes, eval_nodes = get_graph_node_names(
+        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+
+    eval_return_nodes = [eval_nodes[-1]]
+    train_return_nodes = [train_nodes[-1]]
+    if train:
+        tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
+        graph = tracer.trace(model)
+        graph_nodes = list(reversed(graph.nodes))
+        output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
+        graph_node_names = [n.name for n in graph_nodes]
+        output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
+        train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
+
+    fx_model = create_feature_extractor(
+        model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes,
+        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+    return fx_model
+
+
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])
@@ -320,39 +348,23 @@ def test_model_forward_fx(model_name, batch_size):
     model = create_model(model_name, pretrained=False)
     model.eval()
 
-    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
-    if max(input_size) > MAX_FWD_SIZE:
+    input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
+    if max(input_size) > MAX_FWD_FX_SIZE:
         pytest.skip("Fixed input size model > limit.")
-
-    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
-    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
-    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
-    tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
-    graph = tracer.trace(model)
-    graph_nodes = list(reversed(graph.nodes))
-    output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
-    graph_node_names = [n.name for n in graph_nodes]
-    output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
-    train_nodes, eval_nodes = get_graph_node_names(
-        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
-    eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]
-
-    fx_model = create_feature_extractor(
-        model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes,
-        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
-
     inputs = torch.randn((batch_size, *input_size))
     outputs = model(inputs)
     if isinstance(outputs, tuple):
         outputs = torch.cat(outputs)
-    fx_outputs = tuple(fx_model(inputs).values())
+
+    model = _create_fx_model(model)
+    fx_outputs = tuple(model(inputs).values())
     if isinstance(fx_outputs, tuple):
         fx_outputs = torch.cat(fx_outputs)
 
     assert torch.all(fx_outputs == outputs)
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
-    
+
 
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
@@ -362,38 +374,16 @@ def test_model_backward_fx(model_name, batch_size):
     if not has_fx_feature_extraction:
         pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
 
-    input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
-    if max(input_size) > MAX_BWD_SIZE:
+    input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
+    if max(input_size) > MAX_BWD_FX_SIZE:
         pytest.skip("Fixed input size model > limit.")
 
     model = create_model(model_name, pretrained=False, num_classes=42)
-    model.train()
-
     num_params = sum([x.numel() for x in model.parameters()])
+    model.train()
 
-    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
-    if max(input_size) > MAX_FWD_SIZE:
-        pytest.skip("Fixed input size model > limit.")
-
-    # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
-    # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
-    # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
-    tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
-    graph = tracer.trace(model)
-    graph_nodes = list(reversed(graph.nodes))
-    output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
-    graph_node_names = [n.name for n in graph_nodes]
-    output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
-    train_nodes, eval_nodes = get_graph_node_names(
-        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
-    train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
-
-    model = create_feature_extractor(
-        model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]],
-        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
-
-    inputs = torch.randn((batch_size, *input_size))
-    outputs = tuple(model(inputs).values())
+    model = _create_fx_model(model, train=True)
+    outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
     if isinstance(outputs, tuple):
         outputs = torch.cat(outputs)
     outputs.mean().backward()
@@ -412,6 +402,7 @@ EXCLUDE_FX_JIT_FILTERS = [
     'pit_*_distilled_224',
 ]
 
+
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize(
     'model_name', list_models(
@@ -430,18 +421,10 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
     model = create_model(model_name, pretrained=False)
     model.eval()
 
-    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
-    if max(input_size) > MAX_FWD_SIZE:
-        pytest.skip("Fixed input size model > limit.")
-
-    train_nodes, eval_nodes = get_graph_node_names(
-        model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
-    model = create_feature_extractor(
-        model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
-        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
-
-    model = torch.jit.script(model)
-    outputs = model(torch.randn((batch_size, *input_size)))[train_nodes[-1]]
+    model = torch.jit.script(_create_fx_model(model))
+    outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
+    if isinstance(outputs, tuple):
+        outputs = torch.cat(outputs)
 
     assert outputs.shape[0] == batch_size
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
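
For reference, the torchvision FX feature-extraction pattern that the new `_create_fx_model` helper wraps looks roughly like the sketch below. This is not part of the patch: it assumes torch >= 1.10, torchvision >= 0.11 and timm are installed, uses 'resnet18' purely as an arbitrary example model, and omits the `tracer_kwargs` (timm's `_leaf_modules` / `_autowrap_functions`) that the tests pass so that models with untraceable bits still trace.

```python
# Minimal sketch of FX feature extraction on a timm model (illustration only, not part of the patch).
# 'resnet18' and the 224x224 input are arbitrary choices.
import torch
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

from timm import create_model

model = create_model('resnet18', pretrained=False)
model.eval()

# Node names for train and eval mode; the last eval node is the model's final output.
train_nodes, eval_nodes = get_graph_node_names(model)
fx_model = create_feature_extractor(
    model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]])

inputs = torch.randn(1, 3, 224, 224)
# The extractor returns a dict keyed by node name; its single value should match
# the original model's output exactly, which is what test_model_forward_fx asserts.
fx_output = next(iter(fx_model(inputs).values()))
assert torch.all(fx_output == model(inputs))
```

The `train=True` branch of the helper exists because some models return multiple tensors in train mode, so the train-mode return nodes are recovered from the traced graph rather than assumed to be the single last node.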