Attempt to fix unit tests by removing subset of tests on mac runner

pull/1354/head
Ross Wightman 2 years ago
parent 326ade2999
commit 29afe79c8b

@ -4,6 +4,8 @@ import platform
import os import os
import fnmatch import fnmatch
_IS_MAC = platform.system() == 'Darwin'
try: try:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
has_fx_feature_extraction = True has_fx_feature_extraction = True
@ -322,157 +324,160 @@ def test_model_forward_features(model_name, batch_size):
assert not torch.isnan(o).any() assert not torch.isnan(o).any()
def _create_fx_model(model, train=False): if not _IS_MAC:
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode # MACOS test runners are really slow, only running tests below this point if not on a Darwin runner...
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names def _create_fx_model(model, train=False):
tracer_kwargs = dict( # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
leaf_modules=list(_leaf_modules), # So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
autowrap_functions=list(_autowrap_functions), # node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
#enable_cpatching=True, tracer_kwargs = dict(
param_shapes_constant=True leaf_modules=list(_leaf_modules),
) autowrap_functions=list(_autowrap_functions),
train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs) #enable_cpatching=True,
param_shapes_constant=True
eval_return_nodes = [eval_nodes[-1]] )
train_return_nodes = [train_nodes[-1]] train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)
if train:
tracer = NodePathTracer(**tracer_kwargs) eval_return_nodes = [eval_nodes[-1]]
graph = tracer.trace(model) train_return_nodes = [train_nodes[-1]]
graph_nodes = list(reversed(graph.nodes)) if train:
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()] tracer = NodePathTracer(**tracer_kwargs)
graph_node_names = [n.name for n in graph_nodes] graph = tracer.trace(model)
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names] graph_nodes = list(reversed(graph.nodes))
train_return_nodes = [train_nodes[ix] for ix in output_node_indices] output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
fx_model = create_feature_extractor( output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
model, train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
train_return_nodes=train_return_nodes,
eval_return_nodes=eval_return_nodes, fx_model = create_feature_extractor(
tracer_kwargs=tracer_kwargs, model,
) train_return_nodes=train_return_nodes,
return fx_model eval_return_nodes=eval_return_nodes,
tracer_kwargs=tracer_kwargs,
)
EXCLUDE_FX_FILTERS = ['vit_gi*'] return fx_model
# not enough memory to run fx on more models than other tests
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [ EXCLUDE_FX_FILTERS = ['vit_gi*']
'beit_large*', # not enough memory to run fx on more models than other tests
'mixer_l*', if 'GITHUB_ACTIONS' in os.environ:
'*nfnet_f2*', EXCLUDE_FX_FILTERS += [
'*resnext101_32x32d', 'beit_large*',
'resnetv2_152x2*', 'mixer_l*',
'resmlp_big*', '*nfnet_f2*',
'resnetrs270', '*resnext101_32x32d',
'swin_large*', 'resnetv2_152x2*',
'vgg*', 'resmlp_big*',
'vit_large*', 'resnetrs270',
'vit_base_patch8*', 'swin_large*',
'xcit_large*', 'vgg*',
] 'vit_large*',
'vit_base_patch8*',
'xcit_large*',
]
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS)) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size): def test_model_forward_fx(model_name, batch_size):
""" """
Symbolically trace each model and run single forward pass through the resulting GraphModule Symbolically trace each model and run single forward pass through the resulting GraphModule
Also check that the output of a forward pass through the GraphModule is the same as that from the original Module Also check that the output of a forward pass through the GraphModule is the same as that from the original Module
""" """
if not has_fx_feature_extraction: if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
model = create_model(model_name, pretrained=False) model = create_model(model_name, pretrained=False)
model.eval() model.eval()
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE) input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
if max(input_size) > MAX_FWD_FX_SIZE: if max(input_size) > MAX_FWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.") pytest.skip("Fixed input size model > limit.")
with torch.no_grad(): with torch.no_grad():
inputs = torch.randn((batch_size, *input_size)) inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs) outputs = model(inputs)
if isinstance(outputs, tuple): if isinstance(outputs, tuple):
outputs = torch.cat(outputs) outputs = torch.cat(outputs)
model = _create_fx_model(model) model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values()) fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple): if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs) fx_outputs = torch.cat(fx_outputs)
assert torch.all(fx_outputs == outputs) assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs' assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models( @pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True)) exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2]) @pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size): def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule""" """Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction: if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE) input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_FX_SIZE: if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.") pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42) model = create_model(model_name, pretrained=False, num_classes=42)
model.train() model.train()
num_params = sum([x.numel() for x in model.parameters()]) num_params = sum([x.numel() for x in model.parameters()])
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6: if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
pytest.skip("Skipping FX backward test on model with more than 100M params.") pytest.skip("Skipping FX backward test on model with more than 100M params.")
model = _create_fx_model(model, train=True) model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple): if isinstance(outputs, tuple):
outputs = torch.cat(outputs) outputs = torch.cat(outputs)
outputs.mean().backward() outputs.mean().backward()
for n, x in model.named_parameters(): for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}' assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None]) num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42 assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients' assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs' assert not torch.isnan(outputs).any(), 'Output included NaNs'
if 'GITHUB_ACTIONS' not in os.environ: if 'GITHUB_ACTIONS' not in os.environ:
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process # FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [ EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224', 'deit_*_distilled_patch16_224',
'levit*', 'levit*',
'pit_*_distilled_224', 'pit_*_distilled_224',
] + EXCLUDE_FX_FILTERS ] + EXCLUDE_FX_FILTERS
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize( @pytest.mark.parametrize(
'model_name', list_models( 'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True)) exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size): def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass""" """Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction: if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE: if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.") pytest.skip("Fixed input size model > limit.")
with set_scriptable(True): with set_scriptable(True):
model = create_model(model_name, pretrained=False) model = create_model(model_name, pretrained=False)
model.eval() model.eval()
model = torch.jit.script(_create_fx_model(model)) model = torch.jit.script(_create_fx_model(model))
with torch.no_grad(): with torch.no_grad():
outputs = tuple(model(torch.randn((batch_size, *input_size))).values()) outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple): if isinstance(outputs, tuple):
outputs = torch.cat(outputs) outputs = torch.cat(outputs)
assert outputs.shape[0] == batch_size assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs' assert not torch.isnan(outputs).any(), 'Output included NaNs'

Loading…
Cancel
Save