Merge pull request #146 from rwightman/inceptionv3_fix
Remove annoying torchvision InceptionV3 dependency on scipy and insanely slow tru…pull/148/head
commit
63addb741f
@ -1,120 +1,562 @@
|
|||||||
from torchvision.models import Inception3
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from .registry import register_model
|
from .registry import register_model
|
||||||
from .helpers import load_pretrained
|
from .helpers import load_pretrained
|
||||||
|
from .layers import trunc_normal_, SelectAdaptivePool2d
|
||||||
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
|
|
||||||
__all__ = []
|
__all__ = []
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
|
||||||
|
return {
|
||||||
|
'url': url,
|
||||||
|
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||||
|
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||||
|
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||||
|
'first_conv': 'conv1', 'classifier': 'fc',
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
default_cfgs = {
|
default_cfgs = {
|
||||||
# original PyTorch weights, ported from Tensorflow but modified
|
# original PyTorch weights, ported from Tensorflow but modified
|
||||||
'inception_v3': {
|
'inception_v3': _cfg(
|
||||||
'url': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
|
url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
|
||||||
'input_size': (3, 299, 299),
|
has_aux=True), # checkpoint has aux logit layer weights
|
||||||
'crop_pct': 0.875,
|
|
||||||
'interpolation': 'bicubic',
|
|
||||||
'mean': IMAGENET_INCEPTION_MEAN, # also works well enough with resnet defaults
|
|
||||||
'std': IMAGENET_INCEPTION_STD, # also works well enough with resnet defaults
|
|
||||||
'num_classes': 1000,
|
|
||||||
'first_conv': 'conv0',
|
|
||||||
'classifier': 'fc'
|
|
||||||
},
|
|
||||||
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
||||||
'tf_inception_v3': {
|
'tf_inception_v3': _cfg(
|
||||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
|
||||||
'input_size': (3, 299, 299),
|
num_classes=1001, has_aux=False),
|
||||||
'crop_pct': 0.875,
|
|
||||||
'interpolation': 'bicubic',
|
|
||||||
'mean': IMAGENET_INCEPTION_MEAN,
|
|
||||||
'std': IMAGENET_INCEPTION_STD,
|
|
||||||
'num_classes': 1001,
|
|
||||||
'first_conv': 'conv0',
|
|
||||||
'classifier': 'fc'
|
|
||||||
},
|
|
||||||
# my port of Tensorflow adversarially trained Inception V3 from
|
# my port of Tensorflow adversarially trained Inception V3 from
|
||||||
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
||||||
'adv_inception_v3': {
|
'adv_inception_v3': _cfg(
|
||||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
|
||||||
'input_size': (3, 299, 299),
|
num_classes=1001, has_aux=False),
|
||||||
'crop_pct': 0.875,
|
|
||||||
'interpolation': 'bicubic',
|
|
||||||
'mean': IMAGENET_INCEPTION_MEAN,
|
|
||||||
'std': IMAGENET_INCEPTION_STD,
|
|
||||||
'num_classes': 1001,
|
|
||||||
'first_conv': 'conv0',
|
|
||||||
'classifier': 'fc'
|
|
||||||
},
|
|
||||||
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
||||||
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
||||||
'gluon_inception_v3': {
|
'gluon_inception_v3': _cfg(
|
||||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
|
||||||
'input_size': (3, 299, 299),
|
mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults
|
||||||
'crop_pct': 0.875,
|
std=IMAGENET_DEFAULT_STD, # also works well with inception defaults
|
||||||
'interpolation': 'bicubic',
|
has_aux=False,
|
||||||
'mean': IMAGENET_DEFAULT_MEAN, # also works well with inception defaults
|
)
|
||||||
'std': IMAGENET_DEFAULT_STD, # also works well with inception defaults
|
|
||||||
'num_classes': 1000,
|
|
||||||
'first_conv': 'conv0',
|
|
||||||
'classifier': 'fc'
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _assert_default_kwargs(kwargs):
|
class InceptionV3Aux(nn.Module):
|
||||||
# for imported models (ie torchvision) without capability to change these params,
|
"""InceptionV3 with AuxLogits
|
||||||
# make sure they aren't being set to non-defaults
|
"""
|
||||||
assert kwargs.pop('global_pool', 'avg') == 'avg'
|
|
||||||
assert kwargs.pop('drop_rate', 0.) == 0.
|
def __init__(self, inception_blocks=None, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'):
|
||||||
|
super(InceptionV3Aux, self).__init__()
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.drop_rate = drop_rate
|
||||||
|
|
||||||
|
if inception_blocks is None:
|
||||||
|
inception_blocks = [
|
||||||
|
BasicConv2d, InceptionA, InceptionB, InceptionC,
|
||||||
|
InceptionD, InceptionE, InceptionAux
|
||||||
|
]
|
||||||
|
assert len(inception_blocks) == 7
|
||||||
|
conv_block = inception_blocks[0]
|
||||||
|
inception_a = inception_blocks[1]
|
||||||
|
inception_b = inception_blocks[2]
|
||||||
|
inception_c = inception_blocks[3]
|
||||||
|
inception_d = inception_blocks[4]
|
||||||
|
inception_e = inception_blocks[5]
|
||||||
|
inception_aux = inception_blocks[6]
|
||||||
|
|
||||||
|
self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2)
|
||||||
|
self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
|
||||||
|
self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
|
||||||
|
self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
|
||||||
|
self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
|
||||||
|
self.Mixed_5b = inception_a(192, pool_features=32)
|
||||||
|
self.Mixed_5c = inception_a(256, pool_features=64)
|
||||||
|
self.Mixed_5d = inception_a(288, pool_features=64)
|
||||||
|
self.Mixed_6a = inception_b(288)
|
||||||
|
self.Mixed_6b = inception_c(768, channels_7x7=128)
|
||||||
|
self.Mixed_6c = inception_c(768, channels_7x7=160)
|
||||||
|
self.Mixed_6d = inception_c(768, channels_7x7=160)
|
||||||
|
self.Mixed_6e = inception_c(768, channels_7x7=192)
|
||||||
|
self.AuxLogits = inception_aux(768, num_classes)
|
||||||
|
self.Mixed_7a = inception_d(768)
|
||||||
|
self.Mixed_7b = inception_e(1280)
|
||||||
|
self.Mixed_7c = inception_e(2048)
|
||||||
|
|
||||||
|
self.num_features = 2048
|
||||||
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||||
|
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
||||||
|
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
|
||||||
|
trunc_normal_(m.weight, std=stddev)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward_features(self, x):
|
||||||
|
# N x 3 x 299 x 299
|
||||||
|
x = self.Conv2d_1a_3x3(x)
|
||||||
|
# N x 32 x 149 x 149
|
||||||
|
x = self.Conv2d_2a_3x3(x)
|
||||||
|
# N x 32 x 147 x 147
|
||||||
|
x = self.Conv2d_2b_3x3(x)
|
||||||
|
# N x 64 x 147 x 147
|
||||||
|
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||||
|
# N x 64 x 73 x 73
|
||||||
|
x = self.Conv2d_3b_1x1(x)
|
||||||
|
# N x 80 x 73 x 73
|
||||||
|
x = self.Conv2d_4a_3x3(x)
|
||||||
|
# N x 192 x 71 x 71
|
||||||
|
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||||
|
# N x 192 x 35 x 35
|
||||||
|
x = self.Mixed_5b(x)
|
||||||
|
# N x 256 x 35 x 35
|
||||||
|
x = self.Mixed_5c(x)
|
||||||
|
# N x 288 x 35 x 35
|
||||||
|
x = self.Mixed_5d(x)
|
||||||
|
# N x 288 x 35 x 35
|
||||||
|
x = self.Mixed_6a(x)
|
||||||
|
# N x 768 x 17 x 17
|
||||||
|
x = self.Mixed_6b(x)
|
||||||
|
# N x 768 x 17 x 17
|
||||||
|
x = self.Mixed_6c(x)
|
||||||
|
# N x 768 x 17 x 17
|
||||||
|
x = self.Mixed_6d(x)
|
||||||
|
# N x 768 x 17 x 17
|
||||||
|
x = self.Mixed_6e(x)
|
||||||
|
# N x 768 x 17 x 17
|
||||||
|
aux = self.AuxLogits(x) if self.training else None
|
||||||
|
# N x 768 x 17 x 17
|
||||||
|
x = self.Mixed_7a(x)
|
||||||
|
# N x 1280 x 8 x 8
|
||||||
|
x = self.Mixed_7b(x)
|
||||||
|
# N x 2048 x 8 x 8
|
||||||
|
x = self.Mixed_7c(x)
|
||||||
|
# N x 2048 x 8 x 8
|
||||||
|
return x, aux
|
||||||
|
|
||||||
|
def get_classifier(self):
|
||||||
|
return self.fc
|
||||||
|
|
||||||
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||||
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||||
|
self.num_classes = num_classes
|
||||||
|
if self.num_classes > 0:
|
||||||
|
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||||
|
else:
|
||||||
|
self.fc = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x, aux = self.forward_features(x)
|
||||||
|
x = self.global_pool(x).flatten(1)
|
||||||
|
if self.drop_rate > 0:
|
||||||
|
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x, aux
|
||||||
|
|
||||||
|
|
||||||
|
class InceptionV3(nn.Module):
|
||||||
|
"""Inception-V3 with no AuxLogits
|
||||||
|
FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, inception_blocks=None, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'):
|
||||||
|
super(InceptionV3, self).__init__()
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.drop_rate = drop_rate
|
||||||
|
|
||||||
|
if inception_blocks is None:
|
||||||
|
inception_blocks = [
|
||||||
|
BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE]
|
||||||
|
assert len(inception_blocks) >= 6
|
||||||
|
conv_block = inception_blocks[0]
|
||||||
|
inception_a = inception_blocks[1]
|
||||||
|
inception_b = inception_blocks[2]
|
||||||
|
inception_c = inception_blocks[3]
|
||||||
|
inception_d = inception_blocks[4]
|
||||||
|
inception_e = inception_blocks[5]
|
||||||
|
|
||||||
|
self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2)
|
||||||
|
self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
|
||||||
|
self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
|
||||||
|
self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
|
||||||
|
self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
|
||||||
|
self.Mixed_5b = inception_a(192, pool_features=32)
|
||||||
|
self.Mixed_5c = inception_a(256, pool_features=64)
|
||||||
|
self.Mixed_5d = inception_a(288, pool_features=64)
|
||||||
|
self.Mixed_6a = inception_b(288)
|
||||||
|
self.Mixed_6b = inception_c(768, channels_7x7=128)
|
||||||
|
self.Mixed_6c = inception_c(768, channels_7x7=160)
|
||||||
|
self.Mixed_6d = inception_c(768, channels_7x7=160)
|
||||||
|
self.Mixed_6e = inception_c(768, channels_7x7=192)
|
||||||
|
self.Mixed_7a = inception_d(768)
|
||||||
|
self.Mixed_7b = inception_e(1280)
|
||||||
|
self.Mixed_7c = inception_e(2048)
|
||||||
|
|
||||||
|
self.num_features = 2048
|
||||||
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||||
|
self.fc = nn.Linear(2048, num_classes)
|
||||||
|
|
||||||
|
for m in self.modules():
|
||||||
|
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
||||||
|
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
|
||||||
|
trunc_normal_(m.weight, std=stddev)
|
||||||
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
def forward_features(self, x):
|
||||||
|
# N x 3 x 299 x 299
|
||||||
|
x = self.Conv2d_1a_3x3(x)
|
||||||
|
# N x 32 x 149 x 149
|
||||||
|
x = self.Conv2d_2a_3x3(x)
|
||||||
|
# N x 32 x 147 x 147
|
||||||
|
x = self.Conv2d_2b_3x3(x)
|
||||||
|
# N x 64 x 147 x 147
|
||||||
|
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||||
|
# N x 64 x 73 x 73
|
||||||
|
x = self.Conv2d_3b_1x1(x)
|
||||||
|
# N x 80 x 73 x 73
|
||||||
|
x = self.Conv2d_4a_3x3(x)
|
||||||
|
# N x 192 x 71 x 71
|
||||||
|
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||||
|
# N x 192 x 35 x 35
|
||||||
|
x = self.Mixed_5b(x)
|
||||||
|
# N x 256 x 35 x 35
|
||||||
|
x = self.Mixed_5c(x)
|
||||||
|
# N x 288 x 35 x 35
|
||||||
|
x = self.Mixed_5d(x)
|
||||||
|
# N x 288 x 35 x 35
|
||||||
|
x = self.Mixed_6a(x)
|
||||||
|
# N x 768 x 17 x 17
|
||||||
|
x = self.Mixed_6b(x)
|
||||||
|
# N x 768 x 17 x 17
|
||||||
|
x = self.Mixed_6c(x)
|
||||||
|
# N x 768 x 17 x 17
|
||||||
|
x = self.Mixed_6d(x)
|
||||||
|
# N x 768 x 17 x 17
|
||||||
|
x = self.Mixed_6e(x)
|
||||||
|
# N x 768 x 17 x 17
|
||||||
|
x = self.Mixed_7a(x)
|
||||||
|
# N x 1280 x 8 x 8
|
||||||
|
x = self.Mixed_7b(x)
|
||||||
|
# N x 2048 x 8 x 8
|
||||||
|
x = self.Mixed_7c(x)
|
||||||
|
# N x 2048 x 8 x 8
|
||||||
|
return x
|
||||||
|
|
||||||
|
def get_classifier(self):
|
||||||
|
return self.fc
|
||||||
|
|
||||||
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||||
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||||
|
self.num_classes = num_classes
|
||||||
|
if self.num_classes > 0:
|
||||||
|
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||||
|
else:
|
||||||
|
self.fc = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.forward_features(x)
|
||||||
|
x = self.global_pool(x).flatten(1)
|
||||||
|
if self.drop_rate > 0:
|
||||||
|
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class InceptionA(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, pool_features, conv_block=None):
|
||||||
|
super(InceptionA, self).__init__()
|
||||||
|
if conv_block is None:
|
||||||
|
conv_block = BasicConv2d
|
||||||
|
self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
|
||||||
|
|
||||||
|
self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
|
||||||
|
self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
|
||||||
|
|
||||||
|
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
|
||||||
|
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
|
||||||
|
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
branch1x1 = self.branch1x1(x)
|
||||||
|
|
||||||
|
branch5x5 = self.branch5x5_1(x)
|
||||||
|
branch5x5 = self.branch5x5_2(branch5x5)
|
||||||
|
|
||||||
|
branch3x3dbl = self.branch3x3dbl_1(x)
|
||||||
|
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
||||||
|
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
||||||
|
|
||||||
|
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
|
||||||
|
branch_pool = self.branch_pool(branch_pool)
|
||||||
|
|
||||||
|
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
outputs = self._forward(x)
|
||||||
|
return torch.cat(outputs, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class InceptionB(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, conv_block=None):
|
||||||
|
super(InceptionB, self).__init__()
|
||||||
|
if conv_block is None:
|
||||||
|
conv_block = BasicConv2d
|
||||||
|
self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
|
||||||
|
|
||||||
|
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
|
||||||
|
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
|
||||||
|
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
branch3x3 = self.branch3x3(x)
|
||||||
|
|
||||||
|
branch3x3dbl = self.branch3x3dbl_1(x)
|
||||||
|
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
||||||
|
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
||||||
|
|
||||||
|
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||||
|
|
||||||
|
outputs = [branch3x3, branch3x3dbl, branch_pool]
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
outputs = self._forward(x)
|
||||||
|
return torch.cat(outputs, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class InceptionC(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, channels_7x7, conv_block=None):
|
||||||
|
super(InceptionC, self).__init__()
|
||||||
|
if conv_block is None:
|
||||||
|
conv_block = BasicConv2d
|
||||||
|
self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
|
||||||
|
|
||||||
|
c7 = channels_7x7
|
||||||
|
self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
|
||||||
|
self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
|
||||||
|
self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
|
||||||
|
|
||||||
|
self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
|
||||||
|
self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
|
||||||
|
self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
|
||||||
|
self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
|
||||||
|
self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
|
||||||
|
|
||||||
|
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
branch1x1 = self.branch1x1(x)
|
||||||
|
|
||||||
|
branch7x7 = self.branch7x7_1(x)
|
||||||
|
branch7x7 = self.branch7x7_2(branch7x7)
|
||||||
|
branch7x7 = self.branch7x7_3(branch7x7)
|
||||||
|
|
||||||
|
branch7x7dbl = self.branch7x7dbl_1(x)
|
||||||
|
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
||||||
|
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
||||||
|
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
||||||
|
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
||||||
|
|
||||||
|
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
|
||||||
|
branch_pool = self.branch_pool(branch_pool)
|
||||||
|
|
||||||
|
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
outputs = self._forward(x)
|
||||||
|
return torch.cat(outputs, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class InceptionD(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, conv_block=None):
|
||||||
|
super(InceptionD, self).__init__()
|
||||||
|
if conv_block is None:
|
||||||
|
conv_block = BasicConv2d
|
||||||
|
self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
|
||||||
|
self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
|
||||||
|
|
||||||
|
self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
|
||||||
|
self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
|
||||||
|
self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
|
||||||
|
self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
branch3x3 = self.branch3x3_1(x)
|
||||||
|
branch3x3 = self.branch3x3_2(branch3x3)
|
||||||
|
|
||||||
|
branch7x7x3 = self.branch7x7x3_1(x)
|
||||||
|
branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
|
||||||
|
branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
|
||||||
|
branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
|
||||||
|
|
||||||
|
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||||
|
outputs = [branch3x3, branch7x7x3, branch_pool]
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
outputs = self._forward(x)
|
||||||
|
return torch.cat(outputs, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class InceptionE(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, conv_block=None):
|
||||||
|
super(InceptionE, self).__init__()
|
||||||
|
if conv_block is None:
|
||||||
|
conv_block = BasicConv2d
|
||||||
|
self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
|
||||||
|
|
||||||
|
self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
|
||||||
|
self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
|
||||||
|
self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
|
||||||
|
|
||||||
|
self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
|
||||||
|
self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
|
||||||
|
self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
|
||||||
|
self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
|
||||||
|
|
||||||
|
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
branch1x1 = self.branch1x1(x)
|
||||||
|
|
||||||
|
branch3x3 = self.branch3x3_1(x)
|
||||||
|
branch3x3 = [
|
||||||
|
self.branch3x3_2a(branch3x3),
|
||||||
|
self.branch3x3_2b(branch3x3),
|
||||||
|
]
|
||||||
|
branch3x3 = torch.cat(branch3x3, 1)
|
||||||
|
|
||||||
|
branch3x3dbl = self.branch3x3dbl_1(x)
|
||||||
|
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
||||||
|
branch3x3dbl = [
|
||||||
|
self.branch3x3dbl_3a(branch3x3dbl),
|
||||||
|
self.branch3x3dbl_3b(branch3x3dbl),
|
||||||
|
]
|
||||||
|
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
||||||
|
|
||||||
|
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
|
||||||
|
branch_pool = self.branch_pool(branch_pool)
|
||||||
|
|
||||||
|
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
outputs = self._forward(x)
|
||||||
|
return torch.cat(outputs, 1)
|
||||||
|
|
||||||
|
|
||||||
|
class InceptionAux(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, num_classes, conv_block=None):
|
||||||
|
super(InceptionAux, self).__init__()
|
||||||
|
if conv_block is None:
|
||||||
|
conv_block = BasicConv2d
|
||||||
|
self.conv0 = conv_block(in_channels, 128, kernel_size=1)
|
||||||
|
self.conv1 = conv_block(128, 768, kernel_size=5)
|
||||||
|
self.conv1.stddev = 0.01
|
||||||
|
self.fc = nn.Linear(768, num_classes)
|
||||||
|
self.fc.stddev = 0.001
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# N x 768 x 17 x 17
|
||||||
|
x = F.avg_pool2d(x, kernel_size=5, stride=3)
|
||||||
|
# N x 768 x 5 x 5
|
||||||
|
x = self.conv0(x)
|
||||||
|
# N x 128 x 5 x 5
|
||||||
|
x = self.conv1(x)
|
||||||
|
# N x 768 x 1 x 1
|
||||||
|
# Adaptive average pooling
|
||||||
|
x = F.adaptive_avg_pool2d(x, (1, 1))
|
||||||
|
# N x 768 x 1 x 1
|
||||||
|
x = torch.flatten(x, 1)
|
||||||
|
# N x 768
|
||||||
|
x = self.fc(x)
|
||||||
|
# N x 1000
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BasicConv2d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, out_channels, **kwargs):
|
||||||
|
super(BasicConv2d, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
|
||||||
|
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
return F.relu(x, inplace=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _inception_v3(variant, pretrained=False, **kwargs):
|
||||||
|
default_cfg = default_cfgs[variant]
|
||||||
|
if kwargs.pop('features_only', False):
|
||||||
|
assert False, 'Not Implemented' # TODO
|
||||||
|
load_strict = False
|
||||||
|
model_kwargs.pop('num_classes', 0)
|
||||||
|
model_class = InceptionV3
|
||||||
|
else:
|
||||||
|
aux_logits = kwargs.pop('aux_logits', False)
|
||||||
|
if aux_logits:
|
||||||
|
model_class = InceptionV3Aux
|
||||||
|
load_strict = default_cfg['has_aux']
|
||||||
|
else:
|
||||||
|
model_class = InceptionV3
|
||||||
|
load_strict = not default_cfg['has_aux']
|
||||||
|
|
||||||
|
model = model_class(**kwargs)
|
||||||
|
model.default_cfg = default_cfg
|
||||||
|
if pretrained:
|
||||||
|
load_pretrained(
|
||||||
|
model,
|
||||||
|
num_classes=kwargs.get('num_classes', 0),
|
||||||
|
in_chans=kwargs.get('in_chans', 3),
|
||||||
|
strict=load_strict)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def inception_v3(pretrained=False, **kwargs):
|
||||||
# original PyTorch weights, ported from Tensorflow but modified
|
# original PyTorch weights, ported from Tensorflow but modified
|
||||||
default_cfg = default_cfgs['inception_v3']
|
model = _inception_v3('inception_v3', pretrained=pretrained, **kwargs)
|
||||||
assert in_chans == 3
|
|
||||||
_assert_default_kwargs(kwargs)
|
|
||||||
model = Inception3(num_classes=num_classes, aux_logits=True, transform_input=False)
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def tf_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def tf_inception_v3(pretrained=False, **kwargs):
|
||||||
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
||||||
default_cfg = default_cfgs['tf_inception_v3']
|
model = _inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs)
|
||||||
assert in_chans == 3
|
|
||||||
_assert_default_kwargs(kwargs)
|
|
||||||
model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False)
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def adv_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def adv_inception_v3(pretrained=False, **kwargs):
|
||||||
# my port of Tensorflow adversarially trained Inception V3 from
|
# my port of Tensorflow adversarially trained Inception V3 from
|
||||||
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
||||||
default_cfg = default_cfgs['adv_inception_v3']
|
model = _inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs)
|
||||||
assert in_chans == 3
|
|
||||||
_assert_default_kwargs(kwargs)
|
|
||||||
model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False)
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def gluon_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def gluon_inception_v3(pretrained=False, **kwargs):
|
||||||
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
||||||
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
||||||
default_cfg = default_cfgs['gluon_inception_v3']
|
model = _inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs)
|
||||||
assert in_chans == 3
|
|
||||||
_assert_default_kwargs(kwargs)
|
|
||||||
model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False)
|
|
||||||
if pretrained:
|
|
||||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
||||||
model.default_cfg = default_cfg
|
|
||||||
return model
|
return model
|
||||||
|
@ -0,0 +1,60 @@
|
|||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
|
||||||
|
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||||
|
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
||||||
|
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||||
|
def norm_cdf(x):
|
||||||
|
# Computes standard normal cumulative distribution function
|
||||||
|
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
||||||
|
|
||||||
|
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||||
|
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
||||||
|
"The distribution of values may be incorrect.",
|
||||||
|
stacklevel=2)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# Values are generated by using a truncated uniform distribution and
|
||||||
|
# then using the inverse CDF for the normal distribution.
|
||||||
|
# Get upper and lower cdf values
|
||||||
|
l = norm_cdf((a - mean) / std)
|
||||||
|
u = norm_cdf((b - mean) / std)
|
||||||
|
|
||||||
|
# Uniformly fill tensor with values from [l, u], then translate to
|
||||||
|
# [2l-1, 2u-1].
|
||||||
|
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
||||||
|
|
||||||
|
# Use inverse cdf transform for normal distribution to get truncated
|
||||||
|
# standard normal
|
||||||
|
tensor.erfinv_()
|
||||||
|
|
||||||
|
# Transform to proper mean, std
|
||||||
|
tensor.mul_(std * math.sqrt(2.))
|
||||||
|
tensor.add_(mean)
|
||||||
|
|
||||||
|
# Clamp to ensure it's in the proper range
|
||||||
|
tensor.clamp_(min=a, max=b)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
||||||
|
# type: (Tensor, float, float, float, float) -> Tensor
|
||||||
|
r"""Fills the input Tensor with values drawn from a truncated
|
||||||
|
normal distribution. The values are effectively drawn from the
|
||||||
|
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
||||||
|
with values outside :math:`[a, b]` redrawn until they are within
|
||||||
|
the bounds. The method used for generating the random values works
|
||||||
|
best when :math:`a \leq \text{mean} \leq b`.
|
||||||
|
Args:
|
||||||
|
tensor: an n-dimensional `torch.Tensor`
|
||||||
|
mean: the mean of the normal distribution
|
||||||
|
std: the standard deviation of the normal distribution
|
||||||
|
a: the minimum cutoff value
|
||||||
|
b: the maximum cutoff value
|
||||||
|
Examples:
|
||||||
|
>>> w = torch.empty(3, 5)
|
||||||
|
>>> nn.init.trunc_normal_(w)
|
||||||
|
"""
|
||||||
|
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
Loading…
Reference in new issue