"""Inception-V3 model definitions and registered model entrypoints.

Defines the Inception building blocks (A-E, plus the auxiliary classifier
head), the ``InceptionV3`` / ``InceptionV3Aux`` networks, and the registered
factory functions for each pretrained weight variant listed in
``default_cfgs``.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import trunc_normal_, SelectAdaptivePool2d


def _cfg(url='', **kwargs):
    """Build a default pretrained config dict for an Inception-V3 variant.

    Args:
        url: Download location of the pretrained checkpoint.
        **kwargs: Entries that override or extend the defaults below.

    Returns:
        A dict of config values; ``kwargs`` take precedence over the defaults.
    """
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'Conv2d_1a_3x3', 'classifier': 'fc',
        **kwargs
    }


default_cfgs = {
    # original PyTorch weights, ported from Tensorflow but modified
    'inception_v3': _cfg(
        url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
        has_aux=True),  # checkpoint has aux logit layer weights
    # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
    'tf_inception_v3': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
        num_classes=1001, has_aux=False),
    # my port of Tensorflow adversarially trained Inception V3 from
    # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
    'adv_inception_v3': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
        num_classes=1001, has_aux=False),
    # from gluon pretrained models, best performing in terms of accuracy/loss metrics
    # https://gluon-cv.mxnet.io/model_zoo/classification.html
    'gluon_inception_v3': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
        mean=IMAGENET_DEFAULT_MEAN,  # also works well with inception defaults
        std=IMAGENET_DEFAULT_STD,  # also works well with inception defaults
        has_aux=False,
    )
}


class InceptionA(nn.Module):
    """Inception block with 1x1, 5x5, double-3x3 and avg-pool branches.

    Spatial size is preserved (all strides are 1 with matching padding). With
    the default ``BasicConv2d`` the output has ``64 + 64 + 96 + pool_features``
    channels.

    Args:
        in_channels: Number of input channels.
        pool_features: Output channels of the 1x1 conv on the pooling branch.
        conv_block: Callable building a conv layer as
            ``conv_block(in_channels, out_channels, **conv_kwargs)``;
            defaults to ``BasicConv2d``.
    """

    def __init__(self, in_channels, pool_features, conv_block=None):
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)

    def _forward(self, x):
        """Run every branch on ``x`` and return the list of branch outputs."""
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x):
        """Concatenate the branch outputs along the channel dimension."""
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionB(nn.Module):
    """Downsampling Inception block (stride-2 3x3, double-3x3 and max-pool).

    Every branch ends with a stride-2, unpadded 3x3 operation. With the default
    ``BasicConv2d`` the output has ``384 + 96 + in_channels`` channels.

    Args:
        in_channels: Number of input channels.
        conv_block: Conv layer factory; defaults to ``BasicConv2d``.
    """

    def __init__(self, in_channels, conv_block=None):
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

    def _forward(self, x):
        """Run every branch on ``x`` and return the list of branch outputs."""
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # The pooling branch has no conv, so it passes in_channels through.
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x):
        """Concatenate the branch outputs along the channel dimension."""
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):
    """Inception block built from factorized 7x7 convolutions (1x7 and 7x1).

    Spatial size is preserved. With the default ``BasicConv2d`` each of the
    four branches emits 192 channels, for 768 output channels in total.

    Args:
        in_channels: Number of input channels.
        channels_7x7: Width of the intermediate 1x7 / 7x1 convolutions.
        conv_block: Conv layer factory; defaults to ``BasicConv2d``.
    """

    def __init__(self, in_channels, channels_7x7, conv_block=None):
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x):
        """Run every branch on ``x`` and return the list of branch outputs."""
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs

    def forward(self, x):
        """Concatenate the branch outputs along the channel dimension."""
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):
    """Downsampling Inception block (3x3, factorized 7x7 + 3x3, max-pool).

    Every branch ends with a stride-2, unpadded 3x3 operation. With the default
    ``BasicConv2d`` the output has ``320 + 192 + in_channels`` channels.

    Args:
        in_channels: Number of input channels.
        conv_block: Conv layer factory; defaults to ``BasicConv2d``.
    """

    def __init__(self, in_channels, conv_block=None):
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

    def _forward(self, x):
        """Run every branch on ``x`` and return the list of branch outputs."""
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        # The pooling branch has no conv, so it passes in_channels through.
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch7x7x3, branch_pool]
        return outputs

    def forward(self, x):
        """Concatenate the branch outputs along the channel dimension."""
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionE(nn.Module):
    """Inception block whose 3x3 branches fan out into parallel 1x3 / 3x1 convs.

    Spatial size is preserved. With the default ``BasicConv2d`` the output has
    ``320 + 2 * 384 + 2 * 384 + 192 = 2048`` channels.

    Args:
        in_channels: Number of input channels.
        conv_block: Conv layer factory; defaults to ``BasicConv2d``.
    """

    def __init__(self, in_channels, conv_block=None):
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x):
        """Run every branch on ``x`` and return the list of branch outputs."""
        branch1x1 = self.branch1x1(x)

        # The 1x3 and 3x1 convs are applied in parallel to the same tensor and
        # their results are concatenated (they are not chained).
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x):
        """Concatenate the branch outputs along the channel dimension."""
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    Args:
        in_channels: Number of input channels.
        num_classes: Number of output logits.
        conv_block: Conv layer factory; defaults to ``BasicConv2d``.
    """

    def __init__(self, in_channels, num_classes, conv_block=None):
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01
        # The weight-init loop in InceptionV3.__init__ reads `stddev` only from
        # nn.Conv2d / nn.Linear instances. When conv1 is a wrapper such as
        # BasicConv2d, the attribute set above is never seen, so mirror it onto
        # the wrapped convolution to make the override effective.
        wrapped_conv = getattr(self.conv1, 'conv', None)
        if isinstance(wrapped_conv, nn.Conv2d):
            wrapped_conv.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        """Return auxiliary logits for feature map ``x``.

        The shape comments below are for the 768-channel, 17x17 feature map
        produced by a 299x299 network input.
        """
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` (eps=0.001) and ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        **kwargs: Forwarded to ``nn.Conv2d`` (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        """Apply conv -> batch norm -> in-place ReLU."""
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)


class InceptionV3(nn.Module):
    """Inception-V3 with no AuxLogits
    FIXME two class defs are redundant, but less screwing around with torchscript fussiness and inconsistent returns

    Args:
        num_classes: Number of classifier outputs.
        in_chans: Number of input image channels.
        drop_rate: Dropout probability applied before the classifier
            (dropout is skipped when this is 0).
        global_pool: Pooling type passed to ``SelectAdaptivePool2d``.
        aux_logits: When True an ``InceptionAux`` head is created as
            ``self.AuxLogits``; this class never calls it in ``forward``
            (see ``InceptionV3Aux``).
    """

    def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=False):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.aux_logits = aux_logits

        self.Conv2d_1a_3x3 = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)
        else:
            self.AuxLogits = None
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)
        self.feature_info = [
            dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'),
            dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'),
            dict(num_chs=288, reduction=8, module='Mixed_5d'),
            dict(num_chs=768, reduction=16, module='Mixed_6e'),
            dict(num_chs=2048, reduction=32, module='Mixed_7c'),
        ]

        self.num_features = 2048
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.fc = nn.Linear(2048, num_classes)

        # Weight init: conv / linear weights use trunc_normal_ with a per-module
        # `stddev` attribute when one is present (default 0.1); batch norm is
        # reset to weight=1, bias=0.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                stddev = getattr(m, 'stddev', 0.1)
                trunc_normal_(m.weight, std=stddev)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward_preaux(self, x):
        """Run the stem and the Mixed_5* / Mixed_6* stages.

        The shape comments are for a 299x299 input.
        """
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.Pool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.Pool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        return x

    def forward_postaux(self, x):
        """Run the Mixed_7* stages on the output of ``forward_preaux``."""
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        return x

    def forward_features(self, x):
        """Return the final (unpooled) feature map for input ``x``."""
        x = self.forward_preaux(x)
        x = self.forward_postaux(x)
        return x

    def get_classifier(self):
        """Return the classifier module (``self.fc``)."""
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the global pooling layer and the classifier.

        Args:
            num_classes: New number of outputs; a value <= 0 replaces the
                classifier with ``nn.Identity``.
            global_pool: Pooling type passed to ``SelectAdaptivePool2d``.
        """
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.num_classes = num_classes
        if self.num_classes > 0:
            self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
        else:
            self.fc = nn.Identity()

    def forward(self, x):
        """Return classifier outputs for input ``x``."""
        x = self.forward_features(x)
        x = self.global_pool(x).flatten(1)
        if self.drop_rate > 0:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.fc(x)
        return x


class InceptionV3Aux(InceptionV3):
    """InceptionV3 with AuxLogits

    ``forward`` and ``forward_features`` return a ``(output, aux)`` tuple;
    ``aux`` is the auxiliary head output in training mode and ``None``
    otherwise.
    """

    def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=True):
        super().__init__(num_classes, in_chans, drop_rate, global_pool, aux_logits)

    def forward_features(self, x):
        """Return ``(features, aux)``; ``aux`` is None outside training mode."""
        x = self.forward_preaux(x)
        # The aux head is evaluated on the Mixed_6e output, in training only.
        aux = self.AuxLogits(x) if self.training else None
        x = self.forward_postaux(x)
        return x, aux

    def forward(self, x):
        """Return ``(logits, aux)``; ``aux`` is None outside training mode."""
        x, aux = self.forward_features(x)
        x = self.global_pool(x).flatten(1)
        if self.drop_rate > 0:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.fc(x)
        return x, aux


def _create_inception_v3(variant, pretrained=False, **kwargs):
    """Instantiate an Inception-V3 variant through ``build_model_with_cfg``.

    Args:
        variant: Key into ``default_cfgs``.
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: ``aux_logits`` (default False) selects ``InceptionV3Aux``
            over ``InceptionV3``; everything else is forwarded to
            ``build_model_with_cfg``.

    Raises:
        AssertionError: If both ``aux_logits`` and ``features_only`` are truthy.
        KeyError: If ``variant`` is not in ``default_cfgs``.
    """
    default_cfg = default_cfgs[variant]
    aux_logits = kwargs.pop('aux_logits', False)
    if aux_logits:
        # Explicit check instead of `assert`: under `python -O` an assert (and
        # the kwargs.pop inside it) would be stripped, letting features_only
        # slip through. AssertionError is kept so callers see the same type.
        if kwargs.pop('features_only', False):
            raise AssertionError('features_only is not supported together with aux_logits')
        model_cls = InceptionV3Aux
        # Strict loading only works when the checkpoint carries aux weights.
        load_strict = default_cfg['has_aux']
    else:
        model_cls = InceptionV3
        # Without an aux head, a checkpoint that has aux weights holds extra
        # keys, so strict loading must be disabled for it.
        load_strict = not default_cfg['has_aux']
    return build_model_with_cfg(
        model_cls, variant, pretrained,
        default_cfg=default_cfg,
        pretrained_strict=load_strict,
        **kwargs)


@register_model
def inception_v3(pretrained=False, **kwargs):
    # original PyTorch weights, ported from Tensorflow but modified
    model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs)
    return model


@register_model
def tf_inception_v3(pretrained=False, **kwargs):
    # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
    model = _create_inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs)
    return model


@register_model
def adv_inception_v3(pretrained=False, **kwargs):
    # my port of Tensorflow adversarially trained Inception V3 from
    # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
    model = _create_inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs)
    return model


@register_model
def gluon_inception_v3(pretrained=False, **kwargs):
    # from gluon pretrained models, best performing in terms of accuracy/loss metrics
    # https://gluon-cv.mxnet.io/model_zoo/classification.html
    model = _create_inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs)
    return model