|
|
|
@ -7,6 +7,7 @@ import fnmatch
|
|
|
|
|
import timm
|
|
|
|
|
from timm import list_models, create_model, set_scriptable, has_model_default_key, is_model_default_key, \
|
|
|
|
|
get_model_default_value
|
|
|
|
|
from timm.models.fx_features import NodePathTracer
|
|
|
|
|
|
|
|
|
|
if hasattr(torch._C, '_jit_set_profiling_executor'):
|
|
|
|
|
# legacy executor is too slow to compile large models for unit tests
|
|
|
|
@ -297,3 +298,82 @@ def test_model_forward_features(model_name, batch_size):
|
|
|
|
|
assert e == o.shape[1]
|
|
|
|
|
assert o.shape[0] == batch_size
|
|
|
|
|
assert not torch.isnan(o).any()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx(model_name, batch_size):
    """Symbolically trace each model and run single forward pass through the resulting GraphModule"""
    model = create_model(model_name, pretrained=False)
    model.eval()

    # Fixed-input-size models above the forward-pass limit are too big to test here.
    input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
    if max(input_size) > MAX_FWD_SIZE:
        pytest.skip("Fixed input size model > limit.")

    # FX-trace the model and rebuild it as a GraphModule before running it.
    tracer = NodePathTracer()
    fx_graph = tracer.trace(model)
    model = torch.fx.GraphModule(model, fx_graph)

    outputs = model(torch.randn((batch_size, *input_size)))

    assert outputs.shape[0] == batch_size
    assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
    """Symbolically trace each model and run single backward pass through the resulting GraphModule"""
    input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
    if max(input_size) > MAX_BWD_SIZE:
        pytest.skip("Fixed input size model > limit.")

    model = create_model(model_name, pretrained=False, num_classes=42)
    model.train()
    # Count parameters before tracing; the GraphModule below shares the same parameters.
    # Generator expression avoids materializing a throwaway list inside sum().
    num_params = sum(x.numel() for x in model.parameters())

    # FX-trace the model and rebuild it as a GraphModule before running it.
    tracer = NodePathTracer()
    graph = tracer.trace(model)
    model = torch.fx.GraphModule(model, graph)

    inputs = torch.randn((batch_size, *input_size))
    outputs = model(inputs)
    if isinstance(outputs, tuple):
        outputs = torch.cat(outputs)
    outputs.mean().backward()
    for n, x in model.named_parameters():
        assert x.grad is not None, f'No gradient for {n}'
    num_grad = sum(x.grad.numel() for x in model.parameters() if x.grad is not None)

    assert outputs.shape[-1] == 42
    assert num_params == num_grad, 'Some parameters are missing gradients'
    assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
    'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
    """Symbolically trace each model, script it, and run single forward pass"""
    input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
    if max(input_size) > MAX_JIT_SIZE:
        pytest.skip("Fixed input size model > limit.")

    # Build the model with scriptable layer variants so torch.jit.script can compile it.
    with set_scriptable(True):
        model = create_model(model_name, pretrained=False)
    model.eval()

    # FX-trace into a GraphModule, then compile that GraphModule with TorchScript.
    fx_graph = NodePathTracer().trace(model)
    model = torch.jit.script(torch.fx.GraphModule(model, fx_graph))

    outputs = model(torch.randn((batch_size, *input_size)))

    assert outputs.shape[0] == batch_size
    assert not torch.isnan(outputs).any(), 'Output included NaNs'
|