From e78cd790739506fb41d34f3c147ef3d8099f15f1 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 14 Jul 2019 18:17:35 -0700 Subject: [PATCH] Move ResNet additions for Gluon into main ResNet impl. Add ResNet-26 and ResNet-26d models with weights. --- timm/models/resnet.py | 293 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 239 insertions(+), 54 deletions(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 32ff3acf..eb40f16b 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -1,11 +1,13 @@ -"""Pytorch ResNet implementation w/ tweaks -This file is a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with +"""PyTorch ResNet + +This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with additional dropout and dynamic global avg/max pool. -ResNext additions added by Ross Wightman +ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants added by Ross Wightman """ import math +import torch import torch.nn as nn import torch.nn.functional as F @@ -33,6 +35,12 @@ default_cfgs = { 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), 'resnet34': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'), + 'resnet26': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth', + interpolation='bicubic'), + 'resnet26d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth', + interpolation='bicubic'), 'resnet50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth', interpolation='bicubic'), @@ -45,6 +53,7 @@ default_cfgs = { 'resnext50_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d-068914d1.pth', interpolation='bicubic'), + 'resnext50d_32x4d': _cfg(url=''), 'resnext101_32x4d': _cfg(url=''), 'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'), 'resnext101_64x4d': _cfg(url=''), @@ -56,30 +65,60 @@ default_cfgs = { } -def conv3x3(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d( - in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) +def _get_padding(kernel_size, stride, dilation=1): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +class SEModule(nn.Module): + + def __init__(self, channels, reduction_channels): + super(SEModule, self).__init__() + #self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d( + channels, reduction_channels, kernel_size=1, padding=0, bias=True) + self.relu = nn.ReLU() + self.fc2 = nn.Conv2d( + reduction_channels, channels, kernel_size=1, padding=0, bias=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + module_input = x + #x = self.avg_pool(x) + x = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, drop_rate=0.0): + cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock doest not support changing base width' - - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = nn.BatchNorm2d(planes) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = nn.BatchNorm2d(planes) + first_planes = planes // reduce_first + outplanes = planes * self.expansion + + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, + dilation=dilation, bias=False) + self.bn1 = norm_layer(first_planes) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d( + first_planes, outplanes, kernel_size=3, padding=previous_dilation, + dilation=previous_dilation, bias=False) + self.bn2 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None self.downsample = downsample self.stride = stride - self.drop_rate = drop_rate + self.dilation = dilation def forward(self, x): residual = x @@ -87,13 +126,12 @@ class BasicBlock(nn.Module): out = self.conv1(x) out = self.bn1(out) out = self.relu(out) - - if self.drop_rate > 0.: - out = F.dropout(out, p=self.drop_rate, training=self.training) - out = self.conv2(out) out = self.bn2(out) + if self.se is not None: + out = self.se(out) + if self.downsample is not None: residual = self.downsample(x) @@ -107,22 +145,27 @@ class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, drop_rate=0.0): + cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) - - self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(width) - self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, - padding=1, groups=cardinality, bias=False) - self.bn2 = nn.BatchNorm2d(width) - self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * 4) - self.relu = nn.ReLU(inplace=True) + first_planes = width // reduce_first + outplanes = planes * self.expansion + + self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(first_planes) + self.conv2 = nn.Conv2d( + first_planes, width, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(width) + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None + self.relu = nn.ReLU() self.downsample = downsample self.stride = stride - self.drop_rate = drop_rate + self.dilation = dilation def forward(self, x): residual = x @@ -131,9 +174,6 @@ class Bottleneck(nn.Module): out = self.bn1(out) out = self.relu(out) - if self.drop_rate > 0.: - out = F.dropout(out, p=self.drop_rate, training=self.training) - out = self.conv2(out) out = self.bn2(out) out = self.relu(out) @@ -141,6 +181,9 @@ class Bottleneck(nn.Module): out = self.conv3(out) out = self.bn3(out) + if self.se is not None: + out = self.se(out) + if self.downsample is not None: residual = self.downsample(x) @@ -151,26 +194,110 @@ class Bottleneck(nn.Module): class ResNet(nn.Module): - - def __init__(self, block, layers, num_classes=1000, in_chans=3, - cardinality=1, base_width=64, - drop_rate=0.0, block_drop_rate=0.0, - global_pool='avg'): + """ResNet / ResNeXt / SE-ResNeXt / SE-Net + + This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that + * have > 1 stride in the 3x3 conv layer of bottleneck + * have conv-bn-act ordering + + This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s + variants included in the MXNet Gluon ResNetV1b model + + ResNet variants: + * normal - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b + * c - 3 layer deep 3x3 stem, stem_width = 32 + * d - 3 layer deep 3x3 stem, stem_width = 32, average pool in downsample + * e - 3 layer deep 3x3 stem, stem_width = 64, average pool in downsample *no pretrained weights available + * s - 3 layer deep 3x3 stem, stem_width = 64 + + ResNeXt + * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths + * same c,d, e, s variants as ResNet can be enabled + + SE-ResNeXt + * normal - 7x7 stem, stem_width = 64 + * same c, d, e, s variants as ResNet can be enabled + + SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, + reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block + + Parameters + ---------- + block : Block + Class for the residual block. Options are BasicBlockGl, BottleneckGl. + layers : list of int + Numbers of layers in each block + num_classes : int, default 1000 + Number of classification classes. + in_chans : int, default 3 + Number of input (color) channels. + use_se : bool, default False + Enable Squeeze-Excitation module in blocks + cardinality : int, default 1 + Number of convolution groups for 3x3 conv in Bottleneck. + base_width : int, default 64 + Factor determining bottleneck channels. `planes * base_width / 64 * cardinality` + deep_stem : bool, default False + Whether to replace the 7x7 conv1 with 3 3x3 convolution layers. + stem_width : int, default 64 + Number of channels in stem convolutions + block_reduce_first: int, default 1 + Reduction factor for first convolution output width of residual blocks, + 1 for all archs except senets, where 2 + down_kernel_size: int, default 1 + Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets + avg_down : bool, default False + Whether to use average pooling for projection skip connection between stages/downsample. + dilated : bool, default False + Applying dilation strategy to pretrained ResNet yielding a stride-8 model, + typically used in Semantic Segmentation. + drop_rate : float, default 0. + Dropout probability before classifier, for training + global_pool : str, default 'avg' + Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' + """ + def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, + cardinality=1, base_width=64, stem_width=64, deep_stem=False, + block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False, + norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg'): self.num_classes = num_classes - self.inplanes = 64 + self.inplanes = stem_width * 2 if deep_stem else 64 self.cardinality = cardinality self.base_width = base_width self.drop_rate = drop_rate self.expansion = block.expansion + self.dilated = dilated super(ResNet, self).__init__() - self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False) - self.bn1 = nn.BatchNorm2d(64) - self.relu = nn.ReLU(inplace=True) + + if deep_stem: + self.conv1 = nn.Sequential(*[ + nn.Conv2d(in_chans, stem_width, 3, stride=2, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(), + nn.Conv2d(stem_width, stem_width, 3, stride=1, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(), + nn.Conv2d(stem_width, self.inplanes, 3, stride=1, padding=1, bias=False)]) + else: + self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU() self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0], drop_rate=block_drop_rate) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, drop_rate=block_drop_rate) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, drop_rate=block_drop_rate) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, drop_rate=block_drop_rate) + stride_3_4 = 1 if self.dilated else 2 + dilation_3 = 2 if self.dilated else 1 + dilation_4 = 4 if self.dilated else 1 + self.layer1 = self._make_layer( + block, 64, layers[0], stride=1, reduce_first=block_reduce_first, + use_se=use_se, avg_down=avg_down, down_kernel_size=1, norm_layer=norm_layer) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, reduce_first=block_reduce_first, + use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, reduce_first=block_reduce_first, + use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, reduce_first=block_reduce_first, + use_se=use_se, avg_down=avg_down, down_kernel_size=down_kernel_size, norm_layer=norm_layer) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.num_features = 512 * block.expansion self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -182,18 +309,34 @@ class ResNet(nn.Module): nn.init.constant_(m.weight, 1.) nn.init.constant_(m.bias, 0.) - def _make_layer(self, block, planes, blocks, stride=1, drop_rate=0.): + def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1, + use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block.expansion), - ) - - layers = [block(self.inplanes, planes, stride, downsample, self.cardinality, self.base_width, drop_rate)] + downsample_padding = _get_padding(down_kernel_size, stride) + downsample_layers = [] + conv_stride = stride + if avg_down: + avg_stride = stride if dilation == 1 else 1 + conv_stride = 1 + downsample_layers = [nn.AvgPool2d(avg_stride, avg_stride, ceil_mode=True, count_include_pad=False)] + downsample_layers += [ + nn.Conv2d(self.inplanes, planes * block.expansion, down_kernel_size, + stride=conv_stride, padding=downsample_padding, bias=False), + norm_layer(planes * block.expansion)] + downsample = nn.Sequential(*downsample_layers) + + first_dilation = 1 if dilation in (1, 2) else 2 + layers = [block( + self.inplanes, planes, stride, downsample, + cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, + use_se=use_se, dilation=first_dilation, previous_dilation=dilation, norm_layer=norm_layer)] self.inplanes = planes * block.expansion for i in range(1, blocks): - layers.append(block(self.inplanes, planes, cardinality=self.cardinality, base_width=self.base_width)) + layers.append(block( + self.inplanes, planes, + cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, + use_se=use_se, dilation=dilation, previous_dilation=dilation, norm_layer=norm_layer)) return nn.Sequential(*layers) @@ -257,6 +400,33 @@ def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-26 model. + """ + default_cfg = default_cfgs['resnet26'] + model = ResNet(Bottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-26 v1d model. + This is technically a 28 layer ResNet, sticking with 'd' modifier from Gluon for now. + """ + default_cfg = default_cfgs['resnet26d'] + model = ResNet( + Bottleneck, [2, 2, 2, 2], stem_width=32, deep_stem=True, avg_down=True, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-50 model. @@ -362,6 +532,21 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def resnext50d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNeXt50-32x4d model. + """ + default_cfg = default_cfgs['resnext50d_32x4d'] + model = ResNet( + Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, + stem_width=32, deep_stem=True, avg_down=True, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt-101 32x4d model.