Move ResNet additions for Gluon into main ResNet impl. Add ResNet-26 and ResNet-26d models with weights.

pull/19/head
Ross Wightman 5 years ago
parent c37cb0f2b7
commit e78cd79073

@ -1,11 +1,13 @@
"""Pytorch ResNet implementation w/ tweaks
This file is a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
"""PyTorch ResNet
This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
additional dropout and dynamic global avg/max pool.
ResNext additions added by Ross Wightman
ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants added by Ross Wightman
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -33,6 +35,12 @@ default_cfgs = {
'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
'resnet34': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'),
'resnet26': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth',
interpolation='bicubic'),
'resnet26d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth',
interpolation='bicubic'),
'resnet50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth',
interpolation='bicubic'),
@ -45,6 +53,7 @@ default_cfgs = {
'resnext50_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d-068914d1.pth',
interpolation='bicubic'),
'resnext50d_32x4d': _cfg(url=''),
'resnext101_32x4d': _cfg(url=''),
'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
'resnext101_64x4d': _cfg(url=''),
@ -56,30 +65,60 @@ default_cfgs = {
}
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def _get_padding(kernel_size, stride, dilation=1):
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
class SEModule(nn.Module):
def __init__(self, channels, reduction_channels):
super(SEModule, self).__init__()
#self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.relu = nn.ReLU()
self.fc2 = nn.Conv2d(
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
module_input = x
#x = self.avg_pool(x)
x = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, drop_rate=0.0):
cardinality=1, base_width=64, use_se=False,
reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d):
super(BasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
assert base_width == 64, 'BasicBlock doest not support changing base width'
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
first_planes = planes // reduce_first
outplanes = planes * self.expansion
self.conv1 = nn.Conv2d(
inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
dilation=dilation, bias=False)
self.bn1 = norm_layer(first_planes)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=previous_dilation,
dilation=previous_dilation, bias=False)
self.bn2 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.downsample = downsample
self.stride = stride
self.drop_rate = drop_rate
self.dilation = dilation
def forward(self, x):
residual = x
@ -87,13 +126,12 @@ class BasicBlock(nn.Module):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
if self.drop_rate > 0.:
out = F.dropout(out, p=self.drop_rate, training=self.training)
out = self.conv2(out)
out = self.bn2(out)
if self.se is not None:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
@ -107,22 +145,27 @@ class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, drop_rate=0.0):
cardinality=1, base_width=64, use_se=False,
reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d):
super(Bottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality)
self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
padding=1, groups=cardinality, bias=False)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
first_planes = width // reduce_first
outplanes = planes * self.expansion
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
self.bn1 = norm_layer(first_planes)
self.conv2 = nn.Conv2d(
first_planes, width, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(width)
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
self.drop_rate = drop_rate
self.dilation = dilation
def forward(self, x):
residual = x
@ -131,9 +174,6 @@ class Bottleneck(nn.Module):
out = self.bn1(out)
out = self.relu(out)
if self.drop_rate > 0.:
out = F.dropout(out, p=self.drop_rate, training=self.training)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
@ -141,6 +181,9 @@ class Bottleneck(nn.Module):
out = self.conv3(out)
out = self.bn3(out)
if self.se is not None:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
@ -151,26 +194,110 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64,
drop_rate=0.0, block_drop_rate=0.0,
global_pool='avg'):
"""ResNet / ResNeXt / SE-ResNeXt / SE-Net
This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that
* have > 1 stride in the 3x3 conv layer of bottleneck
* have conv-bn-act ordering
This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
variants included in the MXNet Gluon ResNetV1b model
ResNet variants:
* normal - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
* c - 3 layer deep 3x3 stem, stem_width = 32
* d - 3 layer deep 3x3 stem, stem_width = 32, average pool in downsample
* e - 3 layer deep 3x3 stem, stem_width = 64, average pool in downsample *no pretrained weights available
* s - 3 layer deep 3x3 stem, stem_width = 64
ResNeXt
* normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
* same c,d, e, s variants as ResNet can be enabled
SE-ResNeXt
* normal - 7x7 stem, stem_width = 64
* same c, d, e, s variants as ResNet can be enabled
SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64,
reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block
Parameters
----------
block : Block
Class for the residual block. Options are BasicBlockGl, BottleneckGl.
layers : list of int
Numbers of layers in each block
num_classes : int, default 1000
Number of classification classes.
in_chans : int, default 3
Number of input (color) channels.
use_se : bool, default False
Enable Squeeze-Excitation module in blocks
cardinality : int, default 1
Number of convolution groups for 3x3 conv in Bottleneck.
base_width : int, default 64
Factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
deep_stem : bool, default False
Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
stem_width : int, default 64
Number of channels in stem convolutions
block_reduce_first: int, default 1
Reduction factor for first convolution output width of residual blocks,
1 for all archs except senets, where 2
down_kernel_size: int, default 1
Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets
avg_down : bool, default False
Whether to use average pooling for projection skip connection between stages/downsample.
dilated : bool, default False
Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
typically used in Semantic Segmentation.
drop_rate : float, default 0.
Dropout probability before classifier, for training
global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
"""
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
cardinality=1, base_width=64, stem_width=64, deep_stem=False,
block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False,
norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg'):
self.num_classes = num_classes
self.inplanes = 64
self.inplanes = stem_width * 2 if deep_stem else 64
self.cardinality = cardinality
self.base_width = base_width
self.drop_rate = drop_rate
self.expansion = block.expansion
self.dilated = dilated
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
if deep_stem:
self.conv1 = nn.Sequential(*[
nn.Conv2d(in_chans, stem_width, 3, stride=2, padding=1, bias=False),
norm_layer(stem_width),
nn.ReLU(),
nn.Conv2d(stem_width, stem_width, 3, stride=1, padding=1, bias=False),
norm_layer(stem_width),
nn.ReLU(),
nn.Conv2d(stem_width, self.inplanes, 3, stride=1, padding=1, bias=False)])
else:
self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], drop_rate=block_drop_rate)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, drop_rate=block_drop_rate)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, drop_rate=block_drop_rate)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, drop_rate=block_drop_rate)
stride_3_4 = 1 if self.dilated else 2
dilation_3 = 2 if self.dilated else 1
dilation_4 = 4 if self.dilated else 1
self.layer1 = self._make_layer(
block, 64, layers[0], stride=1, reduce_first=block_reduce_first,
use_se=use_se, avg_down=avg_down, down_kernel_size=1, norm_layer=norm_layer)
self.layer2 = self._make_layer(
block, 128, layers[1], stride=2, reduce_first=block_reduce_first,
use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer)
self.layer3 = self._make_layer(
block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, reduce_first=block_reduce_first,
use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer)
self.layer4 = self._make_layer(
block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, reduce_first=block_reduce_first,
use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_features = 512 * block.expansion
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -182,18 +309,34 @@ class ResNet(nn.Module):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
def _make_layer(self, block, planes, blocks, stride=1, drop_rate=0.):
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = [block(self.inplanes, planes, stride, downsample, self.cardinality, self.base_width, drop_rate)]
downsample_padding = _get_padding(down_kernel_size, stride)
downsample_layers = []
conv_stride = stride
if avg_down:
avg_stride = stride if dilation == 1 else 1
conv_stride = 1
downsample_layers = [nn.AvgPool2d(avg_stride, avg_stride, ceil_mode=True, count_include_pad=False)]
downsample_layers += [
nn.Conv2d(self.inplanes, planes * block.expansion, down_kernel_size,
stride=conv_stride, padding=downsample_padding, bias=False),
norm_layer(planes * block.expansion)]
downsample = nn.Sequential(*downsample_layers)
first_dilation = 1 if dilation in (1, 2) else 2
layers = [block(
self.inplanes, planes, stride, downsample,
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
use_se=use_se, dilation=first_dilation, previous_dilation=dilation, norm_layer=norm_layer)]
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, cardinality=self.cardinality, base_width=self.base_width))
layers.append(block(
self.inplanes, planes,
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
use_se=use_se, dilation=dilation, previous_dilation=dilation, norm_layer=norm_layer))
return nn.Sequential(*layers)
@ -257,6 +400,33 @@ def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model
@register_model
def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-26 model.
"""
default_cfg = default_cfgs['resnet26']
model = ResNet(Bottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-26 v1d model.
This is technically a 28 layer ResNet, sticking with 'd' modifier from Gluon for now.
"""
default_cfg = default_cfgs['resnet26d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], stem_width=32, deep_stem=True, avg_down=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-50 model.
@ -362,6 +532,21 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model
@register_model
def resnext50d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt50-32x4d model.
"""
default_cfg = default_cfgs['resnext50d_32x4d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
stem_width=32, deep_stem=True, avg_down=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt-101 32x4d model.

Loading…
Cancel
Save