|
|
|
@ -33,7 +33,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
|
|
|
|
|
EXCLUDE_FILTERS = [
|
|
|
|
|
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
|
|
|
|
|
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
|
|
|
|
|
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'beit_large*']
|
|
|
|
|
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*']
|
|
|
|
|
else:
|
|
|
|
|
EXCLUDE_FILTERS = []
|
|
|
|
|
|
|
|
|
@ -45,6 +45,10 @@ TARGET_JIT_SIZE = 128
|
|
|
|
|
MAX_JIT_SIZE = 320
|
|
|
|
|
TARGET_FFEAT_SIZE = 96
|
|
|
|
|
MAX_FFEAT_SIZE = 256
|
|
|
|
|
TARGET_FWD_FX_SIZE = 128
|
|
|
|
|
MAX_FWD_FX_SIZE = 224
|
|
|
|
|
TARGET_BWD_FX_SIZE = 128
|
|
|
|
|
MAX_BWD_FX_SIZE = 224
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_input_size(model=None, model_name='', target=None):
|
|
|
|
@ -306,6 +310,30 @@ def test_model_forward_features(model_name, batch_size):
|
|
|
|
|
assert not torch.isnan(o).any()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_fx_model(model, train=False):
|
|
|
|
|
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
|
|
|
|
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
|
|
|
|
|
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
|
|
|
|
|
train_nodes, eval_nodes = get_graph_node_names(
|
|
|
|
|
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
|
|
|
|
|
|
|
|
|
|
eval_return_nodes = [eval_nodes[-1]]
|
|
|
|
|
train_return_nodes = [train_nodes[-1]]
|
|
|
|
|
if train:
|
|
|
|
|
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
|
|
|
|
|
graph = tracer.trace(model)
|
|
|
|
|
graph_nodes = list(reversed(graph.nodes))
|
|
|
|
|
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
|
|
|
|
|
graph_node_names = [n.name for n in graph_nodes]
|
|
|
|
|
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
|
|
|
|
|
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
|
|
|
|
|
|
|
|
|
|
fx_model = create_feature_extractor(
|
|
|
|
|
model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes,
|
|
|
|
|
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
|
|
|
|
|
return fx_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.timeout(120)
|
|
|
|
|
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
|
|
|
|
|
@pytest.mark.parametrize('batch_size', [1])
|
|
|
|
@ -320,39 +348,23 @@ def test_model_forward_fx(model_name, batch_size):
|
|
|
|
|
model = create_model(model_name, pretrained=False)
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
|
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
|
|
|
|
|
if max(input_size) > MAX_FWD_SIZE:
|
|
|
|
|
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
|
|
|
|
|
if max(input_size) > MAX_FWD_FX_SIZE:
|
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
|
|
|
|
|
|
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
|
|
|
|
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
|
|
|
|
|
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
|
|
|
|
|
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
|
|
|
|
|
graph = tracer.trace(model)
|
|
|
|
|
graph_nodes = list(reversed(graph.nodes))
|
|
|
|
|
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
|
|
|
|
|
graph_node_names = [n.name for n in graph_nodes]
|
|
|
|
|
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
|
|
|
|
|
train_nodes, eval_nodes = get_graph_node_names(
|
|
|
|
|
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
|
|
|
|
|
eval_return_nodes = [eval_nodes[ix] for ix in output_node_indices]
|
|
|
|
|
|
|
|
|
|
fx_model = create_feature_extractor(
|
|
|
|
|
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=eval_return_nodes,
|
|
|
|
|
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
|
|
|
|
|
|
|
|
|
|
inputs = torch.randn((batch_size, *input_size))
|
|
|
|
|
outputs = model(inputs)
|
|
|
|
|
if isinstance(outputs, tuple):
|
|
|
|
|
outputs = torch.cat(outputs)
|
|
|
|
|
fx_outputs = tuple(fx_model(inputs).values())
|
|
|
|
|
|
|
|
|
|
model = _create_fx_model(model)
|
|
|
|
|
fx_outputs = tuple(model(inputs).values())
|
|
|
|
|
if isinstance(fx_outputs, tuple):
|
|
|
|
|
fx_outputs = torch.cat(fx_outputs)
|
|
|
|
|
|
|
|
|
|
assert torch.all(fx_outputs == outputs)
|
|
|
|
|
assert outputs.shape[0] == batch_size
|
|
|
|
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.timeout(120)
|
|
|
|
|
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
|
|
|
|
@ -362,38 +374,16 @@ def test_model_backward_fx(model_name, batch_size):
|
|
|
|
|
if not has_fx_feature_extraction:
|
|
|
|
|
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
|
|
|
|
|
|
|
|
|
|
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
|
|
|
|
|
if max(input_size) > MAX_BWD_SIZE:
|
|
|
|
|
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
|
|
|
|
|
if max(input_size) > MAX_BWD_FX_SIZE:
|
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
|
|
|
|
|
|
model = create_model(model_name, pretrained=False, num_classes=42)
|
|
|
|
|
model.train()
|
|
|
|
|
|
|
|
|
|
num_params = sum([x.numel() for x in model.parameters()])
|
|
|
|
|
model.train()
|
|
|
|
|
|
|
|
|
|
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
|
|
|
|
|
if max(input_size) > MAX_FWD_SIZE:
|
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
|
|
|
|
|
|
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
|
|
|
|
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
|
|
|
|
|
# node. Then we use those indices to select from train_nodes returned by torchvision get_graph_node_names
|
|
|
|
|
tracer = NodePathTracer(leaf_modules=list(_leaf_modules), autowrap_functions=list(_autowrap_functions))
|
|
|
|
|
graph = tracer.trace(model)
|
|
|
|
|
graph_nodes = list(reversed(graph.nodes))
|
|
|
|
|
output_node_names = [n.name for n in graph_nodes[0]._input_nodes.keys()]
|
|
|
|
|
graph_node_names = [n.name for n in graph_nodes]
|
|
|
|
|
output_node_indices = [-graph_node_names.index(node_name) for node_name in output_node_names]
|
|
|
|
|
train_nodes, eval_nodes = get_graph_node_names(
|
|
|
|
|
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
|
|
|
|
|
train_return_nodes = [train_nodes[ix] for ix in output_node_indices]
|
|
|
|
|
|
|
|
|
|
model = create_feature_extractor(
|
|
|
|
|
model, train_return_nodes=train_return_nodes, eval_return_nodes=[eval_nodes[-1]],
|
|
|
|
|
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
|
|
|
|
|
|
|
|
|
|
inputs = torch.randn((batch_size, *input_size))
|
|
|
|
|
outputs = tuple(model(inputs).values())
|
|
|
|
|
model = _create_fx_model(model, train=True)
|
|
|
|
|
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
|
|
|
|
if isinstance(outputs, tuple):
|
|
|
|
|
outputs = torch.cat(outputs)
|
|
|
|
|
outputs.mean().backward()
|
|
|
|
@ -412,6 +402,7 @@ EXCLUDE_FX_JIT_FILTERS = [
|
|
|
|
|
'pit_*_distilled_224',
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.timeout(120)
|
|
|
|
|
@pytest.mark.parametrize(
|
|
|
|
|
'model_name', list_models(
|
|
|
|
@ -430,18 +421,10 @@ def test_model_forward_fx_torchscript(model_name, batch_size):
|
|
|
|
|
model = create_model(model_name, pretrained=False)
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
|
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
|
|
|
|
|
if max(input_size) > MAX_FWD_SIZE:
|
|
|
|
|
pytest.skip("Fixed input size model > limit.")
|
|
|
|
|
|
|
|
|
|
train_nodes, eval_nodes = get_graph_node_names(
|
|
|
|
|
model, tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
|
|
|
|
|
model = create_feature_extractor(
|
|
|
|
|
model, train_return_nodes=[train_nodes[-1]], eval_return_nodes=[eval_nodes[-1]],
|
|
|
|
|
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
|
|
|
|
|
|
|
|
|
|
model = torch.jit.script(model)
|
|
|
|
|
outputs = model(torch.randn((batch_size, *input_size)))[train_nodes[-1]]
|
|
|
|
|
model = torch.jit.script(_create_fx_model(model))
|
|
|
|
|
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
|
|
|
|
|
if isinstance(outputs, tuple):
|
|
|
|
|
outputs = torch.cat(outputs)
|
|
|
|
|
|
|
|
|
|
assert outputs.shape[0] == batch_size
|
|
|
|
|
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
|
|
|