You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/models/senet.py

467 lines
17 KiB

"""
SENet, SE-ResNet and SE-ResNeXt model definitions.

ResNet code gently borrowed from
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
"""
from __future__ import print_function, division, absolute_import
from collections import OrderedDict
import math
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
# Public API: the SENet class plus every model factory function in this module.
# seresnet18, seresnet34 and seresnext26_32x4d are public factories too and
# were previously missing from the list.
__all__ = [
    'SENet', 'senet154',
    'seresnet18', 'seresnet34', 'seresnet50', 'seresnet101', 'seresnet152',
    'seresnext26_32x4d', 'seresnext50_32x4d', 'seresnext101_32x4d',
]
# Checkpoint URL for each model factory, keyed by the factory's name.
model_urls = {
    'senet154': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
    # NOTE: this entry used to be a copy of the se_resnet50 URL, whose
    # bottleneck weights cannot be loaded into the basic-block seresnet18.
    'seresnet18': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth',
    'seresnet34': 'https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1',
    'seresnet50': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
    'seresnet101': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
    'seresnet152': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
    'seresnext26_32x4d': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth',
    'seresnext50_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
    'seresnext101_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
}
def _weight_init(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
class SEModule(nn.Module):
    """Squeeze-and-Excitation gate.

    Averages the input over its spatial positions, passes the per-channel
    means through two 1x1 convolutions (``channels -> channels // reduction
    -> channels``) with ReLU and sigmoid, and rescales the input channel-wise
    with the result.
    """

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        squeezed_channels = channels // reduction
        self.fc1 = nn.Conv2d(channels, squeezed_channels, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed_channels, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch, chans = x.size(0), x.size(1)
        # Global average pool written with view/mean rather than an
        # nn.AdaptiveAvgPool2d(1) sub-module; result is (batch, chans, 1, 1).
        gate = x.view(batch, chans, -1).mean(-1).view(batch, chans, 1, 1)
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(gate))))
        return x * gate
class Bottleneck(nn.Module):
    """
    Shared `forward()` for the three-convolution bottleneck blocks.

    Subclasses are expected to define `conv1`/`bn1`, `conv2`/`bn2`,
    `conv3`/`bn3`, `relu`, `se_module` and `downsample` (which may be `None`).
    """

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Identity shortcut unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        # The SE gate is applied to the residual branch only, before the add.
        out = self.se_module(out) + shortcut
        return self.relu(out)
class SEBottleneck(Bottleneck):
    """
    Bottleneck for SENet154.

    Channel layout: 1x1 conv `inplanes -> planes * 2`, grouped 3x3 conv
    `planes * 2 -> planes * 4` (carries the stride), 1x1 conv
    `planes * 4 -> planes * 4`, followed by an SE module on `planes * 4`.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None):
        super(SEBottleneck, self).__init__()
        mid_channels = planes * 2
        out_channels = planes * 4

        # Sub-modules are created in a fixed order; it determines both the
        # state_dict layout and the order in which random init is drawn.
        self.conv1 = nn.Conv2d(inplanes, mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(
            mid_channels, out_channels, kernel_size=3, stride=stride,
            padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(
            out_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_channels, reduction=reduction)
        self.downsample = downsample
        self.stride = stride
class SEResNetBottleneck(Bottleneck):
    """
    ResNet bottleneck with a Squeeze-and-Excitation module.

    Follows the Caffe implementation: the stride is applied in `conv1` (the
    first 1x1 convolution) rather than in `conv2` as the torchvision ResNet
    does.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None):
        super(SEResNetBottleneck, self).__init__()
        out_channels = planes * 4

        # Strided 1x1 reduction.
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=1, bias=False, stride=stride)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 convolution at constant width.
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 expansion to planes * 4.
        self.conv3 = nn.Conv2d(planes, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_channels, reduction=reduction)
        self.downsample = downsample
        self.stride = stride
class SEResNeXtBottleneck(Bottleneck):
    """
    ResNeXt bottleneck type C with a Squeeze-and-Excitation module.

    The inner width is `floor(planes * base_width / 64) * groups`; the stride
    is applied in the grouped 3x3 convolution (`conv2`).
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None, base_width=4):
        super(SEResNeXtBottleneck, self).__init__()
        channels_per_group = math.floor(planes * (base_width / 64))
        width = channels_per_group * groups
        out_channels = planes * 4

        self.conv1 = nn.Conv2d(
            inplanes, width, kernel_size=1, bias=False, stride=1)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(
            width, width, kernel_size=3, stride=stride, padding=1,
            groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_channels, reduction=reduction)
        self.downsample = downsample
        self.stride = stride
class SEResNetBlock(nn.Module):
    """
    Two-convolution (basic) residual block with a Squeeze-and-Excitation
    module; the stride is applied in the first 3x3 convolution.
    """
    expansion = 1

    def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):
        super(SEResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=3, padding=1, stride=stride,
            bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, padding=1, groups=groups,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(planes, reduction=reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        # Unlike the bottleneck blocks, a ReLU follows the last BatchNorm
        # before the SE module is applied.
        out = self.relu(self.bn2(self.conv2(out)))

        # Identity shortcut unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.se_module(out) + shortcut
        return self.relu(out)
class SENet(nn.Module):
    """Squeeze-and-Excitation network built from a configurable residual block.

    The network is a stem (`layer0`), four residual stages
    (`layer1`...`layer4`), a global pooling layer (`avg_pool`) and a linear
    classifier (`last_linear`).
    """

    def __init__(self, block, layers, groups, reduction, drop_rate=0.2,
                 inchans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3,
                 downsample_padding=1, num_classes=1000, global_pool='avg'):
        """
        Parameters
        ----------
        block (nn.Module): Residual block class.
            - For SENet154: SEBottleneck
            - For SE-ResNet 18/34: SEResNetBlock
            - For SE-ResNet 50/101/152: SEResNetBottleneck
            - For SE-ResNeXt models: SEResNeXtBottleneck
        layers (list of ints): Number of residual blocks for 4 layers of the
            network (layer1...layer4).
        groups (int): Number of groups for the 3x3 convolution in each
            bottleneck block.
            - For SENet154: 64
            - For SE-ResNet models: 1
            - For SE-ResNeXt models: 32
        reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
            - For all models: 16
        drop_rate (float or None): Drop probability for the dropout applied
            to the pooled features before `last_linear`. `None` or a value
            that is not greater than 0 disables dropout.
        inchans (int): Number of channels of the input image.
        inplanes (int): Number of input channels for layer1.
            - For SENet154: 128
            - For SE-ResNet models: 64
            - For SE-ResNeXt models: 64
        input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
            a single 7x7 convolution in layer0.
            - For SENet154: True
            - For SE-ResNet models: False
            - For SE-ResNeXt models: False
        downsample_kernel_size (int): Kernel size for downsampling convolutions
            in layer2, layer3 and layer4.
            - For SENet154: 3
            - For SE-ResNet models: 1
            - For SE-ResNeXt models: 1
        downsample_padding (int): Padding for downsampling convolutions in
            layer2, layer3 and layer4.
            - For SENet154: 1
            - For SE-ResNet models: 0
            - For SE-ResNeXt models: 0
        num_classes (int): Number of outputs in `last_linear` layer.
            - For all models: 1000
        global_pool (str): Pooling type, passed as `pool_type` to
            `SelectAdaptivePool2d`.
        """
        super(SENet, self).__init__()
        self.inplanes = inplanes
        self.num_classes = num_classes
        if input_3x3:
            layer0_modules = [
                ('conv1', nn.Conv2d(inchans, 64, 3, stride=2, padding=1, bias=False)),
                ('bn1', nn.BatchNorm2d(64)),
                ('relu1', nn.ReLU(inplace=True)),
                ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),
                ('bn2', nn.BatchNorm2d(64)),
                ('relu2', nn.ReLU(inplace=True)),
                ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)),
                ('bn3', nn.BatchNorm2d(inplanes)),
                ('relu3', nn.ReLU(inplace=True)),
            ]
        else:
            layer0_modules = [
                ('conv1', nn.Conv2d(
                    inchans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
                ('bn1', nn.BatchNorm2d(inplanes)),
                ('relu1', nn.ReLU(inplace=True)),
            ]
        # To preserve compatibility with Caffe weights `ceil_mode=True`
        # is used instead of `padding=1`.
        layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, ceil_mode=True)))
        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
        # layer1 keeps the resolution, so its shortcut projection (if any)
        # is always a plain 1x1 convolution.
        self.layer1 = self._make_layer(
            block,
            planes=64,
            blocks=layers[0],
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=1,
            downsample_padding=0
        )
        self.layer2 = self._make_layer(
            block,
            planes=128,
            blocks=layers[1],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.layer3 = self._make_layer(
            block,
            planes=256,
            blocks=layers[2],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.layer4 = self._make_layer(
            block,
            planes=512,
            blocks=layers[3],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.avg_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.drop_rate = drop_rate
        self.num_features = 512 * block.expansion
        self.last_linear = nn.Linear(self.num_features, num_classes)

        for m in self.modules():
            _weight_init(m)

    def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                    downsample_kernel_size=1, downsample_padding=0):
        """Build one residual stage of `blocks` blocks and advance `self.inplanes`.

        Only the first block receives `stride` and the shortcut projection;
        a projection is created whenever the stride or the channel count
        changes across that block.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=downsample_kernel_size, stride=stride,
                          padding=downsample_padding, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(
            self.inplanes, planes, groups, reduction, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups, reduction))

        return nn.Sequential(*layers)

    def get_classifier(self):
        """Return the classifier layer (`None` after `reset_classifier(0)`)."""
        return self.last_linear

    def reset_classifier(self, num_classes):
        """Replace the classifier with a fresh `nn.Linear` for `num_classes`.

        A falsy `num_classes` removes the classifier (`last_linear` becomes
        `None`), in which case `logits`/`forward` return the features.
        """
        self.num_classes = num_classes
        del self.last_linear
        if num_classes:
            self.last_linear = nn.Linear(self.num_features, num_classes)
        else:
            self.last_linear = None

    def forward_features(self, x, pool=True):
        """Run the stem and the four stages.

        With `pool=True` the result is globally pooled and flattened to
        `(batch, -1)`; otherwise the output of `layer4` is returned as is.
        """
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if pool:
            x = self.avg_pool(x)
            x = x.view(x.size(0), -1)
        return x

    def logits(self, x):
        """Apply optional dropout and the classifier to pooled features."""
        # `drop_rate=None` is documented as "no dropout"; comparing None
        # with a float would raise, so test for it explicitly.
        if self.drop_rate is not None and self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        # `reset_classifier(0)` leaves `last_linear` as None; pass the
        # features through unchanged in that case.
        if self.last_linear is not None:
            x = self.last_linear(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.logits(x)
        return x
def _load_pretrained(model, url, inchans=3):
    """Download the checkpoint at `url` and load it into `model`.

    Parameters
    ----------
    model (nn.Module): Model to receive the weights.
    url (str): Location of the checkpoint, fetched with `model_zoo.load_url`.
    inchans (int): Number of input channels of `model`. With 1, the stem
        convolution weights are summed over their input-channel dimension so
        they fit a single-channel input. Only 1 and 3 are supported.

    Raises
    ------
    ValueError: If `inchans` is neither 1 nor 3.
    """
    # Validate before downloading; a real exception is used because
    # `assert` statements are stripped under `python -O`.
    if inchans not in (1, 3):
        raise ValueError("Invalid inchans for pretrained weights")

    state_dict = model_zoo.load_url(url)

    if inchans == 1:
        # The stem convolution is registered as `layer0.conv1`, so that is
        # the key under which a loadable checkpoint stores its weights.
        stem_key = 'layer0.conv1.weight'
        state_dict[stem_key] = state_dict[stem_key].sum(dim=1, keepdim=True)

    # A model built for a different number of classes (or with its
    # classifier removed) cannot take the checkpoint's classifier weights.
    # Keep the model's own classifier in that case so the strict load of
    # every other tensor still succeeds.
    own_state = model.state_dict()
    for key in ('last_linear.weight', 'last_linear.bias'):
        if key not in state_dict:
            continue
        if key not in own_state:
            del state_dict[key]
        elif state_dict[key].shape != own_state[key].shape:
            state_dict[key] = own_state[key]

    model.load_state_dict(state_dict)
def senet154(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
    """Build SENet-154 (SEBottleneck, stages [3, 8, 36, 3], 64 groups).

    Parameters
    ----------
    num_classes (int): Number of classifier outputs.
    inchans (int): Number of input image channels.
    pretrained: Any truthy value loads the weights at
        `model_urls['senet154']`; pass a falsy value for random init.
    **kwargs: Forwarded to `SENet`.
    """
    # `inchans` must be forwarded explicitly; otherwise the model is always
    # built for 3-channel input regardless of what the caller asked for.
    model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
                  inchans=inchans, num_classes=num_classes, **kwargs)
    if pretrained:
        _load_pretrained(model, model_urls['senet154'], inchans)
    return model
def seresnet18(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
    """Build SE-ResNet-18 (SEResNetBlock, stages [2, 2, 2, 2]).

    Parameters
    ----------
    num_classes (int): Number of classifier outputs.
    inchans (int): Number of input image channels.
    pretrained: Any truthy value loads the weights at
        `model_urls['seresnet18']`; pass a falsy value for random init.
    **kwargs: Forwarded to `SENet`.
    """
    # `inchans` must be forwarded explicitly; otherwise the model is always
    # built for 3-channel input regardless of what the caller asked for.
    model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16,
                  inchans=inchans, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
    if pretrained:
        _load_pretrained(model, model_urls['seresnet18'], inchans)
    return model
def seresnet34(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
    """Build SE-ResNet-34 (SEResNetBlock, stages [3, 4, 6, 3]).

    Parameters
    ----------
    num_classes (int): Number of classifier outputs.
    inchans (int): Number of input image channels.
    pretrained: Any truthy value loads the weights at
        `model_urls['seresnet34']`; pass a falsy value for random init.
    **kwargs: Forwarded to `SENet`.
    """
    # `inchans` must be forwarded explicitly; otherwise the model is always
    # built for 3-channel input regardless of what the caller asked for.
    model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16,
                  inchans=inchans, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
    if pretrained:
        _load_pretrained(model, model_urls['seresnet34'], inchans)
    return model
def seresnet50(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
    """Build SE-ResNet-50 (SEResNetBottleneck, stages [3, 4, 6, 3]).

    Parameters
    ----------
    num_classes (int): Number of classifier outputs.
    inchans (int): Number of input image channels.
    pretrained: Any truthy value loads the weights at
        `model_urls['seresnet50']`; pass a falsy value for random init.
    **kwargs: Forwarded to `SENet`.
    """
    # `inchans` must be forwarded explicitly; otherwise the model is always
    # built for 3-channel input regardless of what the caller asked for.
    model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
                  inchans=inchans, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
    if pretrained:
        _load_pretrained(model, model_urls['seresnet50'], inchans)
    return model
def seresnet101(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
    """Build SE-ResNet-101 (SEResNetBottleneck, stages [3, 4, 23, 3]).

    Parameters
    ----------
    num_classes (int): Number of classifier outputs.
    inchans (int): Number of input image channels.
    pretrained: Any truthy value loads the weights at
        `model_urls['seresnet101']`; pass a falsy value for random init.
    **kwargs: Forwarded to `SENet`.
    """
    # `inchans` must be forwarded explicitly; otherwise the model is always
    # built for 3-channel input regardless of what the caller asked for.
    model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
                  inchans=inchans, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
    if pretrained:
        _load_pretrained(model, model_urls['seresnet101'], inchans)
    return model
def seresnet152(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
    """Build SE-ResNet-152 (SEResNetBottleneck, stages [3, 8, 36, 3]).

    Parameters
    ----------
    num_classes (int): Number of classifier outputs.
    inchans (int): Number of input image channels.
    pretrained: Any truthy value loads the weights at
        `model_urls['seresnet152']`; pass a falsy value for random init.
    **kwargs: Forwarded to `SENet`.
    """
    # `inchans` must be forwarded explicitly; otherwise the model is always
    # built for 3-channel input regardless of what the caller asked for.
    model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
                  inchans=inchans, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
    if pretrained:
        _load_pretrained(model, model_urls['seresnet152'], inchans)
    return model
def seresnext26_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
    """Build SE-ResNeXt-26 32x4d (SEResNeXtBottleneck, stages [2, 2, 2, 2]).

    Parameters
    ----------
    num_classes (int): Number of classifier outputs.
    inchans (int): Number of input image channels.
    pretrained: Any truthy value loads the weights registered under
        `model_urls['seresnext26_32x4d']`; pass a falsy value for random init.
    **kwargs: Forwarded to `SENet`.

    Raises
    ------
    KeyError: If `pretrained` is truthy but `model_urls` has no
        'seresnext26_32x4d' entry.
    """
    # `inchans` must be forwarded explicitly; otherwise the model is always
    # built for 3-channel input regardless of what the caller asked for.
    model = SENet(SEResNeXtBottleneck, [2, 2, 2, 2], groups=32, reduction=16,
                  inchans=inchans, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
    if pretrained:
        # Keys in `model_urls` are the factory names (no 'se_' prefix); the
        # old lookup used 'se_resnext26_32x4d', which can never be present.
        url = model_urls.get('seresnext26_32x4d')
        if url is None:
            raise KeyError(
                "no pretrained weights registered for 'seresnext26_32x4d'; "
                "add a URL to model_urls or pass pretrained=False")
        _load_pretrained(model, url, inchans)
    return model
def seresnext50_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
    """Build SE-ResNeXt-50 32x4d (SEResNeXtBottleneck, stages [3, 4, 6, 3]).

    Parameters
    ----------
    num_classes (int): Number of classifier outputs.
    inchans (int): Number of input image channels.
    pretrained: Any truthy value loads the weights at
        `model_urls['seresnext50_32x4d']`; pass a falsy value for random init.
    **kwargs: Forwarded to `SENet`.
    """
    # `inchans` must be forwarded explicitly; otherwise the model is always
    # built for 3-channel input regardless of what the caller asked for.
    model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
                  inchans=inchans, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
    if pretrained:
        _load_pretrained(model, model_urls['seresnext50_32x4d'], inchans)
    return model
def seresnext101_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
    """Build SE-ResNeXt-101 32x4d (SEResNeXtBottleneck, stages [3, 4, 23, 3]).

    Parameters
    ----------
    num_classes (int): Number of classifier outputs.
    inchans (int): Number of input image channels.
    pretrained: Any truthy value loads the weights at
        `model_urls['seresnext101_32x4d']`; pass a falsy value for random init.
    **kwargs: Forwarded to `SENet`.
    """
    # `inchans` must be forwarded explicitly; otherwise the model is always
    # built for 3-channel input regardless of what the caller asked for.
    model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
                  inchans=inchans, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
    if pretrained:
        _load_pretrained(model, model_urls['seresnext101_32x4d'], inchans)
    return model