diff --git a/tests/test_models.py b/tests/test_models.py index 1f2417d4..b29fc639 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -69,6 +69,7 @@ def test_model_backward(model_name, batch_size): @pytest.mark.timeout(120) @pytest.mark.parametrize('model_name', list_models()) +#@pytest.mark.parametrize('model_name', ["xception41"]) @pytest.mark.parametrize('batch_size', [1]) def test_model_default_cfgs(model_name, batch_size): """Run a single forward pass with each model""" @@ -106,8 +107,8 @@ def test_model_default_cfgs(model_name, batch_size): assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] # check classifier and first convolution names match those in default_cfg - assert any([k.startswith(classifier) for k in state_dict.keys()]), f'{classifier} not in model params' - assert any([k.startswith(first_conv) for k in state_dict.keys()]), f'{first_conv} not in model params' + assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params' + assert first_conv + ".weight" in state_dict.keys(), f'{first_conv} not in model params' if 'GITHUB_ACTIONS' not in os.environ: diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index 25385c32..9bd99d04 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -28,22 +28,32 @@ default_cfgs = { 'gluon_resnet50_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth'), 'gluon_resnet101_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1b-3b017079.pth'), 'gluon_resnet152_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1b-c1edb0dd.pth'), - 'gluon_resnet50_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth'), - 'gluon_resnet101_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth'), - 'gluon_resnet152_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth'), - 'gluon_resnet50_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth'), - 'gluon_resnet101_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth'), - 'gluon_resnet152_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth'), - 'gluon_resnet50_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth'), - 'gluon_resnet101_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth'), - 'gluon_resnet152_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth'), + 'gluon_resnet50_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth', + first_conv='conv1.0'), + 'gluon_resnet50_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth', + first_conv='conv1.0'), + 'gluon_resnet50_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth', + first_conv='conv1.0'), + 'gluon_resnet101_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth', + first_conv='conv1.0'), + 'gluon_resnet152_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth', + first_conv='conv1.0'), 'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'), 'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'), 'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'), 'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'), 'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'), 'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'), - 'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth'), + 'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth', + first_conv='conv1.0'), } diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index 6634d4b3..aee1cccc 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -19,7 +19,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - 'first_conv': 'Conv2d_1a_3x3', 'classifier': 'fc', + 'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc', **kwargs } diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 33b24e7b..5a8bb348 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -22,7 +22,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'conv1', 'classifier': 'fc', + 'first_conv': 'conv1.0', 'classifier': 'fc', **kwargs } diff --git a/timm/models/resnet.py b/timm/models/resnet.py index e2799d82..1e7a6b08 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -42,13 +42,15 @@ default_cfgs = { interpolation='bicubic'), 'resnet26d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), 'resnet50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth', interpolation='bicubic'), 'resnet50d': _cfg( url='', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), 'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'), 'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'), 'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'), @@ -62,7 +64,8 @@ default_cfgs = { interpolation='bicubic'), 'resnext50d_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), 'resnext101_32x4d': _cfg(url=''), 'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'), 'resnext101_64x4d': _cfg(url=''), @@ -118,7 +121,8 @@ default_cfgs = { interpolation='bicubic'), 'seresnet50tn': _cfg( url='', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), 'seresnet101': _cfg( url='', interpolation='bicubic'), @@ -132,13 +136,16 @@ default_cfgs = { interpolation='bicubic'), 'seresnext26d_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), 'seresnext26t_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26t_32x4d-361bc1c4.pth', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), 'seresnext26tn_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), 'seresnext50_32x4d': _cfg( interpolation='bicubic'), 'seresnext101_32x4d': _cfg( @@ -149,7 +156,8 @@ default_cfgs = { interpolation='bicubic'), 'senet154': _cfg( url='', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), # Efficient Channel Attention ResNets 'ecaresnet18': _cfg(), @@ -159,21 +167,26 @@ default_cfgs = { interpolation='bicubic'), 'ecaresnet50d': _cfg( url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), 'ecaresnet50d_pruned': _cfg( url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), 'ecaresnet101d': _cfg( url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), 'ecaresnet101d_pruned': _cfg( url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), # Efficient Channel Attention ResNeXts 'ecaresnext26tn_32x4d': _cfg( url='', - interpolation='bicubic'), + interpolation='bicubic', + first_conv='conv1.0'), 'ecaresnext50_32x4d': _cfg( url='', interpolation='bicubic'), diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 815aec06..73bc7732 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -29,7 +29,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem', 'classifier': 'fc', + 'first_conv': 'stem.0', 'classifier': 'fc', **kwargs } diff --git a/timm/models/sknet.py b/timm/models/sknet.py index ddd22a35..6c654922 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -36,7 +36,8 @@ default_cfgs = { 'skresnet34': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'), 'skresnet50': _cfg(), - 'skresnet50d': _cfg(), + 'skresnet50d': _cfg( + first_conv='conv1.0'), 'skresnext50_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth'), } diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 50fc6f48..e060c20d 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -25,7 +25,7 @@ def _cfg(url='', **kwargs): 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': (0, 0, 0), 'std': (1, 1, 1), - 'first_conv': 'body.conv1', 'classifier': 'head.fc', + 'first_conv': 'body.conv1.0', 'classifier': 'head.fc', **kwargs } diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index 973a72da..f3a4a50a 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -25,7 +25,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10), 'crop_pct': 0.903, 'interpolation': 'bicubic', 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - 'first_conv': 'stem.0', 'classifier': 'head.fc', + 'first_conv': 'stem.0.conv', 'classifier': 'head.fc', **kwargs }