Significant ResNet refactor:

* stage creation + make_layer moved to separate fn with more sensible dilation/output_stride calc
* drop path rate decay easy to impl with refactored block creation loops
* fix dilation + blur pool combo
pull/175/head
Ross Wightman 4 years ago
parent a66df5fb91
commit f122f0274b

@ -156,7 +156,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
return output
class DropPath(nn.ModuleDict):
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):

@ -205,14 +205,14 @@ class BasicBlock(nn.Module):
first_planes = planes // reduce_first
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
use_aa = aa_layer is not None
use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
self.conv1 = nn.Conv2d(
inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
dilation=first_dilation, bias=False)
self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True)
self.aa = aa_layer(channels=first_planes) if stride == 2 and use_aa else None
self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None
self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
@ -272,7 +272,7 @@ class Bottleneck(nn.Module):
first_planes = width // reduce_first
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
use_aa = aa_layer is not None
use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
self.bn1 = norm_layer(first_planes)
@ -283,7 +283,7 @@ class Bottleneck(nn.Module):
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(width)
self.act2 = act_layer(inplace=True)
self.aa = aa_layer(channels=width) if stride == 2 and use_aa else None
self.aa = aa_layer(channels=width, stride=stride) if use_aa else None
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
@ -336,14 +336,6 @@ class Bottleneck(nn.Module):
return x
def setup_drop_block(drop_block_rate=0.):
return [
None,
None,
DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
def downsample_conv(
in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
norm_layer = norm_layer or nn.BatchNorm2d
@ -375,6 +367,57 @@ def downsample_avg(
])
def drop_blocks(drop_block_rate=0.):
return [
None, None,
DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
def make_blocks(
block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32,
down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs):
stages = []
feature_info = []
net_num_blocks = sum(block_repeats)
net_block_idx = 0
net_stride = 4
dilation = prev_dilation = 1
for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
stride = 1 if stage_idx == 0 else 2
if net_stride >= output_stride:
dilation *= stride
stride = 1
else:
net_stride *= stride
downsample = None
if stride != 1 or inplanes != planes * block_fn.expansion:
down_kwargs = dict(
in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size,
stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer'))
downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
blocks = []
for block_idx in range(num_blocks):
downsample = downsample if block_idx == 0 else None
stride = stride if block_idx == 0 else 1
block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule
blocks.append(block_fn(
inplanes, planes, stride, downsample, first_dilation=prev_dilation,
drop_path=DropPath(block_dpr) if block_dpr > 0. else None, **block_kwargs))
prev_dilation = dilation
inplanes = planes * block_fn.expansion
net_block_idx += 1
stages.append((stage_name, nn.Sequential(*blocks)))
feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
return stages, feature_info
class ResNet(nn.Module):
"""ResNet / ResNeXt / SE-ResNeXt / SE-Net
@ -448,21 +491,18 @@ class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, stem_width=64, stem_type='',
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
block_args = block_args or dict()
assert output_stride in (8, 16, 32)
self.num_classes = num_classes
deep_stem = 'deep' in stem_type
self.inplanes = stem_width * 2 if deep_stem else 64
self.cardinality = cardinality
self.base_width = base_width
self.drop_rate = drop_rate
self.expansion = block.expansion
super(ResNet, self).__init__()
# Stem
deep_stem = 'deep' in stem_type
inplanes = stem_width * 2 if deep_stem else 64
if deep_stem:
stem_chs_1 = stem_chs_2 = stem_width
if 'tiered' in stem_type:
@ -475,43 +515,31 @@ class ResNet(nn.Module):
nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False),
norm_layer(stem_chs_2),
act_layer(inplace=True),
nn.Conv2d(stem_chs_2, self.inplanes, 3, stride=1, padding=1, bias=False)])
nn.Conv2d(stem_chs_2, inplanes, 3, stride=1, padding=1, bias=False)])
else:
self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.inplanes)
self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(inplanes)
self.act1 = act_layer(inplace=True)
self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')]
self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
# Stem Pooling
if aa_layer is not None:
self.maxpool = nn.Sequential(*[
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
aa_layer(channels=self.inplanes, stride=2)
])
aa_layer(channels=inplanes, stride=2)])
else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Feature Blocks
channels = [64, 128, 256, 512]
dp = DropPath(drop_path_rate) if drop_path_rate else None
db = setup_drop_block(drop_block_rate)
layer_kwargs = dict(
reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
total_stride = 4
dilation = 1
for i in range(4):
layer_name = f'layer{i + 1}'
stride = 2 if i > 0 else 1
if total_stride >= output_stride:
dilation *= stride
stride = 1
else:
total_stride *= stride
self.add_module(layer_name, self._make_layer(
block, channels[i], layers[i], stride, dilation, drop_block=db[i], **layer_kwargs))
self.feature_info.append(dict(
num_chs=self.inplanes, reduction=total_stride, module=layer_name))
stage_modules, stage_feature_info = make_blocks(
block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
for stage in stage_modules:
self.add_module(*stage) # layer1, layer2, etc
self.feature_info.extend(stage_feature_info)
# Head (Pooling and Classifier)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@ -529,25 +557,6 @@ class ResNet(nn.Module):
if hasattr(m, 'zero_init_last_bn'):
m.zero_init_last_bn()
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
avg_down=False, down_kernel_size=1, **kwargs):
downsample = None
first_dilation = 1 if dilation in (1, 2) else 2
if stride != 1 or self.inplanes != planes * block.expansion:
downsample_args = dict(
in_channels=self.inplanes, out_channels=planes * block.expansion, kernel_size=down_kernel_size,
stride=stride, dilation=dilation, first_dilation=first_dilation, norm_layer=kwargs.get('norm_layer'))
downsample = downsample_avg(**downsample_args) if avg_down else downsample_conv(**downsample_args)
block_kwargs = dict(
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
dilation=dilation, **kwargs)
layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
self.inplanes = planes * block.expansion
layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
return nn.Sequential(*layers)
def get_classifier(self):
return self.fc

Loading…
Cancel
Save