ResNet / Res2Net additions:

* ResNet torchscript compat
* output_stride arg supported to limit network stride via dilations (support for dilation added to Res2Net)
* allow activation layer to be changed via act_layer arg
pull/82/head
Ross Wightman 5 years ago
parent f96b3e5e92
commit 53001dd292

@ -54,9 +54,8 @@ class Bottle2neck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=26, scale=4, use_se=False,
norm_layer=None, dilation=1, previous_dilation=1, **_):
act_layer=nn.ReLU, norm_layer=None, dilation=1, previous_dilation=1, **_):
super(Bottle2neck, self).__init__()
assert dilation == 1 and previous_dilation == 1 # FIXME support dilation
self.scale = scale
self.is_first = stride > 1 or downsample is not None
self.num_scales = max(1, scale - 1)
@ -71,18 +70,20 @@ class Bottle2neck(nn.Module):
bns = []
for i in range(self.num_scales):
convs.append(nn.Conv2d(
width, width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False))
width, width, kernel_size=3, stride=stride, padding=dilation,
dilation=dilation, groups=cardinality, bias=False))
bns.append(norm_layer(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
if self.is_first:
# FIXME this should probably have count_include_pad=False, but hurts original weights
self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.relu = nn.ReLU(inplace=True)
self.relu = act_layer(inplace=True)
self.downsample = downsample
def forward(self, x):

@ -125,11 +125,12 @@ class SEModule(nn.Module):
class BasicBlock(nn.Module):
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, use_se=False,
reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d):
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(BasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@ -141,12 +142,13 @@ class BasicBlock(nn.Module):
inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
dilation=dilation, bias=False)
self.bn1 = norm_layer(first_planes)
self.relu = nn.ReLU(inplace=True)
self.act1 = act_layer(inplace=True)
self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=previous_dilation,
dilation=previous_dilation, bias=False)
self.bn2 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.act2 = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
@ -156,7 +158,7 @@ class BasicBlock(nn.Module):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.act1(out)
out = self.conv2(out)
out = self.bn2(out)
@ -167,17 +169,18 @@ class BasicBlock(nn.Module):
residual = self.downsample(x)
out += residual
out = self.relu(out)
out = self.act2(out)
return out
class Bottleneck(nn.Module):
__constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, use_se=False,
reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d):
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(Bottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality)
@ -186,14 +189,16 @@ class Bottleneck(nn.Module):
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True)
self.conv2 = nn.Conv2d(
first_planes, width, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(width)
self.act2 = act_layer(inplace=True)
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.relu = nn.ReLU(inplace=True)
self.act3 = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
@ -203,11 +208,11 @@ class Bottleneck(nn.Module):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.act1(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.act2(out)
out = self.conv3(out)
out = self.bn3(out)
@ -219,7 +224,7 @@ class Bottleneck(nn.Module):
residual = self.downsample(x)
out += residual
out = self.relu(out)
out = self.act3(out)
return out
@ -284,9 +289,10 @@ class ResNet(nn.Module):
Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets
avg_down : bool, default False
Whether to use average pooling for projection skip connection between stages/downsample.
dilated : bool, default False
Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
typically used in Semantic Segmentation.
output_stride : int, default 32
Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
act_layer : class, activation layer
norm_layer : class, normalization layer
drop_rate : float, default 0.
Dropout probability before classifier, for training
global_pool : str, default 'avg'
@ -294,8 +300,8 @@ class ResNet(nn.Module):
"""
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
cardinality=1, base_width=64, stem_width=64, stem_type='',
block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False,
norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
zero_init_last_bn=True, block_args=None):
block_args = block_args or dict()
self.num_classes = num_classes
@ -305,9 +311,9 @@ class ResNet(nn.Module):
self.base_width = base_width
self.drop_rate = drop_rate
self.expansion = block.expansion
self.dilated = dilated
super(ResNet, self).__init__()
# Stem
if deep_stem:
stem_chs_1 = stem_chs_2 = stem_width
if 'tiered' in stem_type:
@ -316,25 +322,37 @@ class ResNet(nn.Module):
self.conv1 = nn.Sequential(*[
nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False),
norm_layer(stem_chs_1),
nn.ReLU(inplace=True),
act_layer(inplace=True),
nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False),
norm_layer(stem_chs_2),
nn.ReLU(inplace=True),
act_layer(inplace=True),
nn.Conv2d(stem_chs_2, self.inplanes, 3, stride=1, padding=1, bias=False)])
else:
self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.act1 = act_layer(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
stride_3_4 = 1 if self.dilated else 2
dilation_3 = 2 if self.dilated else 1
dilation_4 = 4 if self.dilated else 1
largs = dict(use_se=use_se, reduce_first=block_reduce_first, norm_layer=norm_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
self.layer1 = self._make_layer(block, 64, layers[0], stride=1, **largs)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, **largs)
self.layer3 = self._make_layer(block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, **largs)
self.layer4 = self._make_layer(block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, **largs)
# Feature Blocks
channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
if output_stride == 16:
strides[3] = 1
dilations[3] = 2
elif output_stride == 8:
strides[2:4] = [1, 1]
dilations[2:4] = [2, 4]
else:
assert output_stride == 32
llargs = list(zip(channels, layers, strides, dilations))
lkwargs = dict(
use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
self.layer3 = self._make_layer(block, *llargs[2], **lkwargs)
self.layer4 = self._make_layer(block, *llargs[3], **lkwargs)
# Head (Pooling and Classifier)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_features = 512 * block.expansion
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -352,7 +370,8 @@ class ResNet(nn.Module):
nn.init.constant_(m.bias, 0.)
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d, **kwargs):
use_se=False, avg_down=False, down_kernel_size=1, **kwargs):
norm_layer = kwargs.get('norm_layer')
downsample = None
down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
if stride != 1 or self.inplanes != planes * block.expansion:
@ -370,15 +389,15 @@ class ResNet(nn.Module):
downsample = nn.Sequential(*downsample_layers)
first_dilation = 1 if dilation in (1, 2) else 2
bargs = dict(
bkwargs = dict(
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
use_se=use_se, norm_layer=norm_layer, **kwargs)
use_se=use_se, **kwargs)
layers = [block(
self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bargs)]
self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)]
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(
self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bargs))
self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bkwargs))
return nn.Sequential(*layers)
@ -394,7 +413,7 @@ class ResNet(nn.Module):
def forward_features(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.act1(x)
x = self.maxpool(x)
x = self.layer1(x)

Loading…
Cancel
Save