Attempt to reduce memory footprint of FX tests for GitHub actions runs

pull/989/head
Ross Wightman 3 years ago
parent bdd3dff0ca
commit 9b3519545d

@ -33,7 +33,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
EXCLUDE_FILTERS = [ EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'beit_large*'] '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*']
else: else:
EXCLUDE_FILTERS = [] EXCLUDE_FILTERS = []
@ -45,6 +45,10 @@ TARGET_JIT_SIZE = 128
MAX_JIT_SIZE = 320 MAX_JIT_SIZE = 320
TARGET_FFEAT_SIZE = 96 TARGET_FFEAT_SIZE = 96
MAX_FFEAT_SIZE = 256 MAX_FFEAT_SIZE = 256
TARGET_FWD_FX_SIZE = 128
MAX_FWD_FX_SIZE = 224
TARGET_BWD_FX_SIZE = 128
MAX_BWD_FX_SIZE = 224
def _get_input_size(model=None, model_name='', target=None): def _get_input_size(model=None, model_name='', target=None):
@ -306,6 +310,30 @@ def test_model_forward_features(model_name, batch_size):
assert not torch.isnan(o).any() assert not torch.isnan(o).any()
def _create_fx_model(model, train=False):
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
eval_return_nodes = [eval_nodes[-1]]
train_return_nodes = [train_nodes[-1]]
if train:
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
fx_model = create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
return fx_model
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
@ -320,32 +348,16 @@ def test_model_forward_fx(model_name, batch_size):
model = create_model(model_name, pretrained=False) model = create_model(model_name, pretrained=False)
model.eval() model.eval()
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
if max(input_size) > MAX_FWD_SIZE: if max(input_size) > MAX_FWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.") pytest.skip("Fixed input size model > limit.")
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]
fx_model = create_feature_extractor(
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes,
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
inputs = torch.randn((batch_size, *input_size)) inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs) outputs = model(inputs)
if isinstance(outputs, tuple): if isinstance(outputs, tuple):
outputs = torch.cat(outputs) outputs = torch.cat(outputs)
fx_outputs = tuple(fx_model(inputs).values())
model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple): if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs) fx_outputs = torch.cat(fx_outputs)
@ -362,38 +374,16 @@ def test_model_backward_fx(model_name, batch_size):
if not has_fx_feature_extraction: if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_SIZE: if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.") pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42) model = create_model(model_name, pretrained=False, num_classes=42)
model.train()
num_params = sum([x.numel() for x in model.parameters()]) num_params = sum([x.numel() for x in model.parameters()])
model.train()
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) model = _create_fx_model(model, train=True)
if max(input_size) > MAX_FWD_SIZE: outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
pytest.skip("Fixed input size model > limit.")
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
graph = tracer.trace(model)
graph_nodes = list(reversed(graph.nodes))
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
graph_node_names = [n.name for n in graph_nodes]
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
model = create_feature_extractor(
model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]],
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
inputs = torch.randn((batch_size, *input_size))
outputs = tuple(model(inputs).values())
if isinstance(outputs, tuple): if isinstance(outputs, tuple):
outputs = torch.cat(outputs) outputs = torch.cat(outputs)
outputs.mean().backward() outputs.mean().backward()
@ -412,6 +402,7 @@ EXCLUDE_FX_JIT_FILTERS = [
'pit_*_distilled_224', 'pit_*_distilled_224',
] ]
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize( @pytest.mark.parametrize(
'model_name', list_models( 'model_name', list_models(
@ -430,18 +421,10 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
model = create_model(model_name, pretrained=False) model = create_model(model_name, pretrained=False)
model.eval() model.eval()
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) model = torch.jit.script(_create_fx_model(model))
if max(input_size) > MAX_FWD_SIZE: outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
pytest.skip("Fixed input size model > limit.") if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
train_nodes, eval_nodes = get_graph_node_names(
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
model = create_feature_extractor(
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
model = torch.jit.script(model)
outputs = model(torch.randn((batch_size, *input_size)))[train_nodes[-1]]
assert outputs.shape[0] == batch_size assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs' assert not torch.isnan(outputs).any(), 'Output included NaNs'

Loading…
Cancel
Save