@ -6,16 +6,25 @@ from .helpers import load_pretrained
from . adaptive_avgmax_pool import *
from timm . data import IMAGENET_INCEPTION_MEAN , IMAGENET_INCEPTION_STD
_models = [ ' inception_resnet_v2 ' ]
_models = [ ' inception_resnet_v2 ' , ' ens_adv_inception_resnet_v2 ' ]
__all__ = [ ' InceptionResnetV2 ' ] + _models
default_cfgs = {
# ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
' inception_resnet_v2 ' : {
' url ' : ' http ://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4 .pth' ,
' url ' : ' http s://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6 .pth' ,
' num_classes ' : 1001 , ' input_size ' : ( 3 , 299 , 299 ) , ' pool_size ' : ( 8 , 8 ) ,
' crop_pct ' : 0.8975 , ' interpolation ' : ' bicubic ' ,
' mean ' : IMAGENET_INCEPTION_MEAN , ' std ' : IMAGENET_INCEPTION_STD ,
' first_conv ' : ' conv2d_1a.conv ' , ' classifier ' : ' last_linear ' ,
' first_conv ' : ' conv2d_1a.conv ' , ' classifier ' : ' classif ' ,
} ,
# ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
' ens_adv_inception_resnet_v2 ' : {
' url ' : ' https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth ' ,
' num_classes ' : 1001 , ' input_size ' : ( 3 , 299 , 299 ) , ' pool_size ' : ( 8 , 8 ) ,
' crop_pct ' : 0.8975 , ' interpolation ' : ' bicubic ' ,
' mean ' : IMAGENET_INCEPTION_MEAN , ' std ' : IMAGENET_INCEPTION_STD ,
' first_conv ' : ' conv2d_1a.conv ' , ' classifier ' : ' classif ' ,
}
}
@ -274,19 +283,20 @@ class InceptionResnetV2(nn.Module):
)
self . block8 = Block8 ( noReLU = True )
self . conv2d_7b = BasicConv2d ( 2080 , self . num_features , kernel_size = 1 , stride = 1 )
self . last_linear = nn . Linear ( self . num_features , num_classes )
# NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
self . classif = nn . Linear ( self . num_features , num_classes )
def get_classifier ( self ) :
return self . last_linear
return self . classif
def reset_classifier ( self , num_classes , global_pool = ' avg ' ) :
self . global_pool = global_pool
self . num_classes = num_classes
del self . last_linear
del self . classif
if num_classes :
self . last_linear = torch . nn . Linear ( self . num_features , num_classes )
self . classif = torch . nn . Linear ( self . num_features , num_classes )
else :
self . last_linear = None
self . classif = None
def forward_features ( self , x , pool = True ) :
x = self . conv2d_1a ( x )
@ -314,13 +324,13 @@ class InceptionResnetV2(nn.Module):
x = self . forward_features ( x , pool = True )
if self . drop_rate > 0 :
x = F . dropout ( x , p = self . drop_rate , training = self . training )
x = self . last_linear ( x )
x = self . classif ( x )
return x
def inception_resnet_v2 ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
r """ InceptionResnetV2 model architecture from the
` " InceptionV4, Inception-ResNet... " < https : / / arxiv . org / abs / 1602.07261 > ` _ paper .
` " InceptionV4, Inception-ResNet... " < https : / / arxiv . org / abs / 1602.07261 > ` paper .
"""
default_cfg = default_cfgs [ ' inception_resnet_v2 ' ]
model = InceptionResnetV2 ( num_classes = num_classes , in_chans = in_chans , * * kwargs )
@ -330,3 +340,16 @@ def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs
return model
def ens_adv_inception_resnet_v2 ( pretrained = False , num_classes = 1000 , in_chans = 3 , * * kwargs ) :
r """ Ensemble Adversarially trained InceptionResnetV2 model architecture
As per https : / / arxiv . org / abs / 1705.07204 and
https : / / github . com / tensorflow / models / tree / master / research / adv_imagenet_models .
"""
default_cfg = default_cfgs [ ' ens_adv_inception_resnet_v2 ' ]
model = InceptionResnetV2 ( num_classes = num_classes , in_chans = in_chans , * * kwargs )
model . default_cfg = default_cfg
if pretrained :
load_pretrained ( model , default_cfg , num_classes , in_chans )
return model