Some pretrianed URL changes

* host some of Cadene's weights on github instead of .fr for speed
* add my old port of ensemble adversarial inception resnet v2
* switch to my TF port of normal inception res v2 and change FC layer back to 'classif' for compat with ens_adv
pull/16/head
Ross Wightman 6 years ago
parent 827a3d6010
commit 87b92c528e

@ -6,16 +6,25 @@ from .helpers import load_pretrained
from .adaptive_avgmax_pool import *
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
_models = ['inception_resnet_v2']
_models = ['inception_resnet_v2', 'ens_adv_inception_resnet_v2']
__all__ = ['InceptionResnetV2'] + _models
default_cfgs = {
# ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
'inception_resnet_v2': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.8975, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'conv2d_1a.conv', 'classifier': 'last_linear',
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
},
# ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
'ens_adv_inception_resnet_v2': {
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.8975, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
}
}
@ -274,19 +283,20 @@ class InceptionResnetV2(nn.Module):
)
self.block8 = Block8(noReLU=True)
self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
self.last_linear = nn.Linear(self.num_features, num_classes)
# NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
self.classif = nn.Linear(self.num_features, num_classes)
def get_classifier(self):
return self.last_linear
return self.classif
def reset_classifier(self, num_classes, global_pool='avg'):
self.global_pool = global_pool
self.num_classes = num_classes
del self.last_linear
del self.classif
if num_classes:
self.last_linear = torch.nn.Linear(self.num_features, num_classes)
self.classif = torch.nn.Linear(self.num_features, num_classes)
else:
self.last_linear = None
self.classif = None
def forward_features(self, x, pool=True):
x = self.conv2d_1a(x)
@ -314,13 +324,13 @@ class InceptionResnetV2(nn.Module):
x = self.forward_features(x, pool=True)
if self.drop_rate > 0:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.last_linear(x)
x = self.classif(x)
return x
def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
r"""InceptionResnetV2 model architecture from the
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
"""
default_cfg = default_cfgs['inception_resnet_v2']
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
@ -330,3 +340,16 @@ def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs
return model
def ens_adv_inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
As per https://arxiv.org/abs/1705.07204 and
https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
"""
default_cfg = default_cfgs['ens_adv_inception_resnet_v2']
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

@ -11,7 +11,7 @@ __all__ = ['InceptionV4'] + _models
default_cfgs = {
'inception_v4': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,

@ -20,7 +20,7 @@ __all__ = ['PNASNet5Large'] + _models
default_cfgs = {
'pnasnet5large': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth',
'input_size': (3, 331, 331),
'pool_size': (11, 11),
'crop_pct': 0.875,

@ -37,20 +37,20 @@ def _cfg(url='', **kwargs):
default_cfgs = {
'senet154':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'),
'seresnet18':
_cfg(url='https://www.dropbox.com/s/3o3nd8mfhxod7rq/seresnet18-4bb0ce65.pth?dl=1',
interpolation='bicubic'),
'seresnet34':
_cfg(url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'),
'seresnet50':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth'),
'seresnet101':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth'),
'seresnet152':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth'),
'seresnext26_32x4d':
_cfg(url='https://www.dropbox.com/s/zaeruz2bejcdhh3/seresnext26_32x4d-65ebdb501.pth?dl=1',
interpolation='bicubic'),
'seresnet18': _cfg(
url='https://www.dropbox.com/s/3o3nd8mfhxod7rq/seresnet18-4bb0ce65.pth?dl=1',
interpolation='bicubic'),
'seresnet34': _cfg(
url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'),
'seresnet50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'),
'seresnet101': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'),
'seresnet152': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'),
'seresnext26_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth',
interpolation='bicubic'),
'seresnext50_32x4d':
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'),
'seresnext101_32x4d':

@ -35,7 +35,7 @@ __all__ = ['Xception'] + _models
default_cfgs = {
'xception': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth',
'input_size': (3, 299, 299),
'crop_pct': 0.8975,
'interpolation': 'bicubic',

Loading…
Cancel
Save