diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index 4b884a59..29c68a8b 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -6,16 +6,25 @@ from .helpers import load_pretrained from .adaptive_avgmax_pool import * from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -_models = ['inception_resnet_v2'] +_models = ['inception_resnet_v2', 'ens_adv_inception_resnet_v2'] __all__ = ['InceptionResnetV2'] + _models default_cfgs = { + # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz 'inception_resnet_v2': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth', + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth', 'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), 'crop_pct': 0.8975, 'interpolation': 'bicubic', 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, - 'first_conv': 'conv2d_1a.conv', 'classifier': 'last_linear', + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', + }, + # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz + 'ens_adv_inception_resnet_v2': { + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth', + 'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'crop_pct': 0.8975, 'interpolation': 'bicubic', + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif', } } @@ -274,19 +283,20 @@ class InceptionResnetV2(nn.Module): ) self.block8 = Block8(noReLU=True) self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1) - self.last_linear = nn.Linear(self.num_features, num_classes) + # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC + self.classif = nn.Linear(self.num_features, num_classes) def get_classifier(self): - return self.last_linear + return self.classif def reset_classifier(self, num_classes, global_pool='avg'): self.global_pool = global_pool self.num_classes = num_classes - del self.last_linear + del self.classif if num_classes: - self.last_linear = torch.nn.Linear(self.num_features, num_classes) + self.classif = torch.nn.Linear(self.num_features, num_classes) else: - self.last_linear = None + self.classif = None def forward_features(self, x, pool=True): x = self.conv2d_1a(x) @@ -314,13 +324,13 @@ class InceptionResnetV2(nn.Module): x = self.forward_features(x, pool=True) if self.drop_rate > 0: x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.last_linear(x) + x = self.classif(x) return x def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): r"""InceptionResnetV2 model architecture from the - `"InceptionV4, Inception-ResNet..." `_ paper. + `"InceptionV4, Inception-ResNet..." ` paper. """ default_cfg = default_cfgs['inception_resnet_v2'] model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs) @@ -330,3 +340,16 @@ def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs return model + +def ens_adv_inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + r""" Ensemble Adversarially trained InceptionResnetV2 model architecture + As per https://arxiv.org/abs/1705.07204 and + https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models. + """ + default_cfg = default_cfgs['ens_adv_inception_resnet_v2'] + model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + + return model diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index f25f6515..ac819cfe 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -11,7 +11,7 @@ __all__ = ['InceptionV4'] + _models default_cfgs = { 'inception_v4': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth', + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth', 'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 76a63590..c4b25820 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -20,7 +20,7 @@ __all__ = ['PNASNet5Large'] + _models default_cfgs = { 'pnasnet5large': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth', 'input_size': (3, 331, 331), 'pool_size': (11, 11), 'crop_pct': 0.875, diff --git a/timm/models/senet.py b/timm/models/senet.py index 31841e7f..22283116 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -37,20 +37,20 @@ def _cfg(url='', **kwargs): default_cfgs = { 'senet154': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'), - 'seresnet18': - _cfg(url='https://www.dropbox.com/s/3o3nd8mfhxod7rq/seresnet18-4bb0ce65.pth?dl=1', - interpolation='bicubic'), - 'seresnet34': - _cfg(url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'), - 'seresnet50': - _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth'), - 'seresnet101': - _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth'), - 'seresnet152': - _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth'), - 'seresnext26_32x4d': - _cfg(url='https://www.dropbox.com/s/zaeruz2bejcdhh3/seresnext26_32x4d-65ebdb501.pth?dl=1', - interpolation='bicubic'), + 'seresnet18': _cfg( + url='https://www.dropbox.com/s/3o3nd8mfhxod7rq/seresnet18-4bb0ce65.pth?dl=1', + interpolation='bicubic'), + 'seresnet34': _cfg( + url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'), + 'seresnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'), + 'seresnet101': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'), + 'seresnet152': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'), + 'seresnext26_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth', + interpolation='bicubic'), 'seresnext50_32x4d': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'), 'seresnext101_32x4d': diff --git a/timm/models/xception.py b/timm/models/xception.py index de536408..a2d63b6e 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -35,7 +35,7 @@ __all__ = ['Xception'] + _models default_cfgs = { 'xception': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth', + 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth', 'input_size': (3, 299, 299), 'crop_pct': 0.8975, 'interpolation': 'bicubic',