From afb6bd066910e71fb7b2621ab098a3826ccc27b8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 21 May 2020 15:28:36 -0700 Subject: [PATCH] Add backward and default_cfg tests and fix a few issues found. Fix #153 --- tests/test_inference.py | 19 ---------- tests/test_models.py | 70 +++++++++++++++++++++++++++++++++++ timm/models/dla.py | 10 +++-- timm/models/gluon_xception.py | 2 +- timm/models/hrnet.py | 2 +- timm/models/inception_v3.py | 2 +- timm/models/mobilenetv3.py | 2 +- timm/models/nasnet.py | 4 +- timm/models/resnest.py | 9 +++-- timm/models/selecsls.py | 2 +- timm/models/tresnet.py | 8 ++-- timm/models/xception.py | 1 + 12 files changed, 95 insertions(+), 36 deletions(-) delete mode 100644 tests/test_inference.py create mode 100644 tests/test_models.py diff --git a/tests/test_inference.py b/tests/test_inference.py deleted file mode 100644 index 2490a0bc..00000000 --- a/tests/test_inference.py +++ /dev/null @@ -1,19 +0,0 @@ -import pytest -import torch - -from timm import list_models, create_model - - -@pytest.mark.timeout(300) -@pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*')) -@pytest.mark.parametrize('batch_size', [1]) -def test_model_forward(model_name, batch_size): - """Run a single forward pass with each model""" - model = create_model(model_name, pretrained=False) - model.eval() - - inputs = torch.randn((batch_size, *model.default_cfg['input_size'])) - outputs = model(inputs) - - assert outputs.shape[0] == batch_size - assert not torch.isnan(outputs).any(), 'Output included NaNs' diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 00000000..65a7ebb3 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,70 @@ +import pytest +import torch + +from timm import list_models, create_model + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models()) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_forward(model_name, batch_size): + """Run a single forward pass with each model""" + model = create_model(model_name, pretrained=False) + model.eval() + + input_size = model.default_cfg['input_size'] + if any([x > 448 for x in input_size]): + # cap forward test at max res 448 * 448 to keep resource down + input_size = tuple([min(x, 448) for x in input_size]) + inputs = torch.randn((batch_size, *input_size)) + outputs = model(inputs) + + assert outputs.shape[0] == batch_size + assert not torch.isnan(outputs).any(), 'Output included NaNs' + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models(exclude_filters='dla*')) # DLA models have an issue TBD +@pytest.mark.parametrize('batch_size', [2]) +def test_model_backward(model_name, batch_size): + """Run a single forward pass with each model""" + model = create_model(model_name, pretrained=False, num_classes=42) + num_params = sum([x.numel() for x in model.parameters()]) + model.eval() + + input_size = model.default_cfg['input_size'] + if any([x > 128 for x in input_size]): + # cap backward test at 128 * 128 to keep resource usage down + input_size = tuple([min(x, 128) for x in input_size]) + inputs = torch.randn((batch_size, *input_size)) + outputs = model(inputs) + outputs.mean().backward() + num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None]) + + assert outputs.shape[-1] == 42 + assert num_params == num_grad, 'Some parameters are missing gradients' + assert not torch.isnan(outputs).any(), 'Output included NaNs' + + +@pytest.mark.timeout(120) +@pytest.mark.parametrize('model_name', list_models()) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_default_cfgs(model_name, batch_size): + """Run a single forward pass with each model""" + model = create_model(model_name, pretrained=False) + model.eval() + state_dict = model.state_dict() + cfg = model.default_cfg + + classifier = cfg['classifier'] + first_conv = cfg['first_conv'] + pool_size = cfg['pool_size'] + input_size = model.default_cfg['input_size'] + + if all([x <= 448 for x in input_size]): + # pool size only checked if default res <= 448 * 448 to keep resource down + input_size = tuple([min(x, 448) for x in input_size]) + outputs = model.forward_features(torch.randn((batch_size, *input_size))) + assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] + assert any([k.startswith(cfg['classifier']) for k in state_dict.keys()]), f'{classifier} not in model params' + assert any([k.startswith(cfg['first_conv']) for k in state_dict.keys()]), f'{first_conv} not in model params' diff --git a/timm/models/dla.py b/timm/models/dla.py index a9e81d16..94803e69 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -237,8 +237,11 @@ class DlaTree(nn.Module): def forward(self, x, residual=None, children=None): children = [] if children is None else children - bottom = self.downsample(x) if self.downsample else x - residual = self.project(bottom) if self.project else bottom + # FIXME the way downsample / project are used here and residual is passed to next level up + # the tree, the residual is overridden and some project weights are thus never used and + # have no gradients. This appears to be an issue with the original model / weights. + bottom = self.downsample(x) if self.downsample is not None else x + residual = self.project(bottom) if self.project is not None else bottom if self.level_root: children.append(bottom) x1 = self.tree1(x, residual) @@ -354,7 +357,8 @@ def dla60_res2next(pretrained=None, num_classes=1000, in_chans=3, **kwargs): @register_model def dla34(pretrained=None, num_classes=1000, in_chans=3, **kwargs): # DLA-34 default_cfg = default_cfgs['dla34'] - model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic, **kwargs) + model = DLA([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=DlaBasic, + num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index 2fc8e699..a737b8f7 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -36,7 +36,7 @@ default_cfgs = { 'url': '', 'input_size': (3, 299, 299), 'crop_pct': 0.875, - 'pool_size': (10, 10), + 'pool_size': (5, 5), 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 06327c65..ac4824bb 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -34,7 +34,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'conv1', 'classifier': 'fc', + 'first_conv': 'conv1', 'classifier': 'classifier', **kwargs } diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index 0997e024..ffaab4f1 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -15,7 +15,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - 'first_conv': 'conv1', 'classifier': 'fc', + 'first_conv': 'Conv2d_1a_3x3', 'classifier': 'fc', **kwargs } diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 86ca9f7a..9c0e863a 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -21,7 +21,7 @@ __all__ = ['MobileNetV3'] def _cfg(url='', **kwargs): return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'conv_stem', 'classifier': 'classifier', diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 8847b1de..511b006b 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -19,7 +19,7 @@ default_cfgs = { 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'num_classes': 1001, - 'first_conv': 'conv_0.conv', + 'first_conv': 'conv0.conv', 'classifier': 'last_linear', }, } @@ -612,7 +612,7 @@ def nasnetalarge(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """NASNet-A large model architecture. """ default_cfg = default_cfgs['nasnetalarge'] - model = NASNetALarge(num_classes=1000, in_chans=in_chans, **kwargs) + model = NASNetALarge(num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 33b051ef..884894d9 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -38,11 +38,14 @@ default_cfgs = { 'resnest50d': _cfg( url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'), 'resnest101e': _cfg( - url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth', input_size=(3, 256, 256)), + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), 'resnest200e': _cfg( - url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)), + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', + input_size=(3, 320, 320), pool_size=(10, 10)), 'resnest269e': _cfg( - url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)), + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', + input_size=(3, 416, 416), pool_size=(13, 13)), 'resnest50d_4s2x40d': _cfg( url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_4s2x40d-41d14ed0.pth', interpolation='bicubic'), diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 2f369e99..6b83421b 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -26,7 +26,7 @@ __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this def _cfg(url='', **kwargs): return { 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (3, 3), + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem', 'classifier': 'fc', diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 48b3e1de..fbbcf318 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -28,7 +28,7 @@ def _cfg(url='', **kwargs): 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': (0, 0, 0), 'std': (1, 1, 1), - 'first_conv': 'layer0.conv1', 'classifier': 'head.fc', + 'first_conv': 'body.conv1', 'classifier': 'head.fc', **kwargs } @@ -41,13 +41,13 @@ default_cfgs = { 'tresnet_xl': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'), 'tresnet_m_448': _cfg( - input_size=(3, 448, 448), + input_size=(3, 448, 448), pool_size=(14, 14), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'), 'tresnet_l_448': _cfg( - input_size=(3, 448, 448), + input_size=(3, 448, 448), pool_size=(14, 14), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'), 'tresnet_xl_448': _cfg( - input_size=(3, 448, 448), + input_size=(3, 448, 448), pool_size=(14, 14), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth') } diff --git a/timm/models/xception.py b/timm/models/xception.py index cb98bbc9..f04dabfd 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -37,6 +37,7 @@ default_cfgs = { 'xception': { 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth', 'input_size': (3, 299, 299), + 'pool_size': (10, 10), 'crop_pct': 0.8975, 'interpolation': 'bicubic', 'mean': (0.5, 0.5, 0.5),