Significant ResNet refactor:

* stage creation + make_layer moved to separate fn with more sensible dilation/output_stride calc
* drop path rate decay is easy to implement with the refactored block creation loops
* fix dilation + blur pool combo
pull/175/head
Ross Wightman 4 years ago
parent a66df5fb91
commit f122f0274b

@@ -156,7 +156,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
     return output


-class DropPath(nn.ModuleDict):
+class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
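Note: DropPath had been accidentally subclassing nn.ModuleDict; it only worked because the class holds no child modules, and the one-word fix restores the intended base class. A minimal sketch of the stochastic depth operation it wraps, equivalent in spirit to the drop_path function named in the hunk header (the actual timm implementation may differ in detail):

import torch
import torch.nn as nn

def drop_path(x, drop_prob: float = 0., training: bool = False):
    # Zero entire samples' residual branches at random; rescale survivors by
    # 1/keep_prob so the expected activation magnitude is unchanged.
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    mask = (torch.rand(x.shape[0], 1, 1, 1, dtype=x.dtype, device=x.device) < keep_prob).to(x.dtype)
    return x * mask / keep_prob

class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob or 0., self.training)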

@@ -205,14 +205,14 @@ class BasicBlock(nn.Module):
        first_planes = planes // reduce_first
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation
-        use_aa = aa_layer is not None
+        use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)

        self.conv1 = nn.Conv2d(
            inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
            dilation=first_dilation, bias=False)
        self.bn1 = norm_layer(first_planes)
        self.act1 = act_layer(inplace=True)
-        self.aa = aa_layer(channels=first_planes) if stride == 2 and use_aa else None
+        self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None

        self.conv2 = nn.Conv2d(
            first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
@@ -272,7 +272,7 @@ class Bottleneck(nn.Module):
        first_planes = width // reduce_first
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation
-        use_aa = aa_layer is not None
+        use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)

        self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
        self.bn1 = norm_layer(first_planes)
@@ -283,7 +283,7 @@ class Bottleneck(nn.Module):
            padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
        self.bn2 = norm_layer(width)
        self.act2 = act_layer(inplace=True)
-        self.aa = aa_layer(channels=width) if stride == 2 and use_aa else None
+        self.aa = aa_layer(channels=width, stride=stride) if use_aa else None

        self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
        self.bn3 = norm_layer(outplanes)
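Note: the use_aa changes above (the same edit in BasicBlock and Bottleneck) are the "fix dilation + blur pool combo" bullet from the commit message. Previously a block could only get an anti-aliasing layer when its stride was 2, so at a dilation transition (stride traded for dilation, first_dilation != dilation) no blur was applied at all; the new gate covers that case and forwards the block's actual stride instead of assuming 2. A rough shape check, assuming timm's BlurPool2d as the aa_layer (import path assumed for this timm vintage):

import torch
from timm.models.layers import BlurPool2d  # assumption: where BlurPool2d lives in this version

x = torch.randn(2, 64, 56, 56)
aa2 = BlurPool2d(channels=64, stride=2)  # low-pass filter, then downsample by 2
print(aa2(x).shape)  # expected: torch.Size([2, 64, 28, 28])
aa1 = BlurPool2d(channels=64, stride=1)  # stride 1: blur only, resolution preserved
print(aa1(x).shape)  # expected: torch.Size([2, 64, 56, 56])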
@@ -336,14 +336,6 @@ class Bottleneck(nn.Module):
        return x


-def setup_drop_block(drop_block_rate=0.):
-    return [
-        None,
-        None,
-        DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
-        DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
-
-
 def downsample_conv(
        in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
    norm_layer = norm_layer or nn.BatchNorm2d
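For reference, downsample_conv above and its sibling downsample_avg are the two shortcut projections the new make_blocks (next hunk) picks between via avg_down. A simplified sketch of the distinction, ignoring the dilation edge cases the real helpers handle (names suffixed _sketch are illustrative, not timm's):

import torch.nn as nn

def downsample_conv_sketch(in_chs, out_chs, stride=1, norm_layer=nn.BatchNorm2d):
    # Classic ResNet shortcut: a strided 1x1 conv projects channels and downsamples at once.
    return nn.Sequential(
        nn.Conv2d(in_chs, out_chs, 1, stride=stride, bias=False),
        norm_layer(out_chs))

def downsample_avg_sketch(in_chs, out_chs, stride=1, norm_layer=nn.BatchNorm2d):
    # ResNet-D style shortcut: average pooling absorbs the stride, so the 1x1 conv
    # sees every input pixel instead of skipping three out of four.
    pool = nn.AvgPool2d(2, stride, ceil_mode=True, count_include_pad=False) if stride > 1 else nn.Identity()
    return nn.Sequential(
        pool,
        nn.Conv2d(in_chs, out_chs, 1, stride=1, bias=False),
        norm_layer(out_chs))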
@@ -375,6 +367,57 @@ def downsample_avg(
    ])


+def drop_blocks(drop_block_rate=0.):
+    return [
+        None, None,
+        DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
+        DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
+
+
+def make_blocks(
+        block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32,
+        down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs):
+    stages = []
+    feature_info = []
+    net_num_blocks = sum(block_repeats)
+    net_block_idx = 0
+    net_stride = 4
+    dilation = prev_dilation = 1
+    for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
+        stage_name = f'layer{stage_idx + 1}'  # never liked this name, but weight compat requires it
+        stride = 1 if stage_idx == 0 else 2
+        if net_stride >= output_stride:
+            dilation *= stride
+            stride = 1
+        else:
+            net_stride *= stride
+
+        downsample = None
+        if stride != 1 or inplanes != planes * block_fn.expansion:
+            down_kwargs = dict(
+                in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size,
+                stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer'))
+            downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
+
+        block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
+        blocks = []
+        for block_idx in range(num_blocks):
+            downsample = downsample if block_idx == 0 else None
+            stride = stride if block_idx == 0 else 1
+            block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1)  # stochastic depth linear decay rule
+            blocks.append(block_fn(
+                inplanes, planes, stride, downsample, first_dilation=prev_dilation,
+                drop_path=DropPath(block_dpr) if block_dpr > 0. else None, **block_kwargs))
+            prev_dilation = dilation
+            inplanes = planes * block_fn.expansion
+            net_block_idx += 1
+
+        stages.append((stage_name, nn.Sequential(*blocks)))
+        feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
+
+    return stages, feature_info
+
+
 class ResNet(nn.Module):
    """ResNet / ResNeXt / SE-ResNeXt / SE-Net
@@ -448,21 +491,18 @@ class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, in_chans=3,
                 cardinality=1, base_width=64, stem_width=64, stem_type='',
-                 block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
+                 output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
                 drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
        block_args = block_args or dict()
        assert output_stride in (8, 16, 32)
        self.num_classes = num_classes
-        deep_stem = 'deep' in stem_type
-        self.inplanes = stem_width * 2 if deep_stem else 64
-        self.cardinality = cardinality
-        self.base_width = base_width
        self.drop_rate = drop_rate
-        self.expansion = block.expansion
        super(ResNet, self).__init__()

        # Stem
+        deep_stem = 'deep' in stem_type
+        inplanes = stem_width * 2 if deep_stem else 64
        if deep_stem:
            stem_chs_1 = stem_chs_2 = stem_width
            if 'tiered' in stem_type:
@@ -475,43 +515,31 @@ class ResNet(nn.Module):
                nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False),
                norm_layer(stem_chs_2),
                act_layer(inplace=True),
-                nn.Conv2d(stem_chs_2, self.inplanes, 3, stride=1, padding=1, bias=False)])
+                nn.Conv2d(stem_chs_2, inplanes, 3, stride=1, padding=1, bias=False)])
        else:
-            self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
-        self.bn1 = norm_layer(self.inplanes)
+            self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = norm_layer(inplanes)
        self.act1 = act_layer(inplace=True)
-        self.feature_info = [dict(num_chs=self.inplanes, reduction=2, module='act1')]
+        self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]

        # Stem Pooling
        if aa_layer is not None:
            self.maxpool = nn.Sequential(*[
                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
-                aa_layer(channels=self.inplanes, stride=2)
-            ])
+                aa_layer(channels=inplanes, stride=2)])
        else:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Feature Blocks
        channels = [64, 128, 256, 512]
-        dp = DropPath(drop_path_rate) if drop_path_rate else None
-        db = setup_drop_block(drop_block_rate)
-        layer_kwargs = dict(
-            reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
-            avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
-        total_stride = 4
-        dilation = 1
-        for i in range(4):
-            layer_name = f'layer{i + 1}'
-            stride = 2 if i > 0 else 1
-            if total_stride >= output_stride:
-                dilation *= stride
-                stride = 1
-            else:
-                total_stride *= stride
-            self.add_module(layer_name, self._make_layer(
-                block, channels[i], layers[i], stride, dilation, drop_block=db[i], **layer_kwargs))
-            self.feature_info.append(dict(
-                num_chs=self.inplanes, reduction=total_stride, module=layer_name))
+        stage_modules, stage_feature_info = make_blocks(
+            block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
+            output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
+            down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
+            drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
+        for stage in stage_modules:
+            self.add_module(*stage)  # layer1, layer2, etc
+        self.feature_info.extend(stage_feature_info)

        # Head (Pooling and Classifier)
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
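End to end, the constructor now delegates all stage construction to make_blocks, so output_stride and drop_path_rate work for every ResNet variant built from this class. A hedged usage sketch (assumes this timm version's create_model forwards these kwargs to the ResNet constructor, which the signature above suggests):

import torch
import timm

# output_stride=8 dilates layer3/layer4 instead of striding them; drop_path_rate
# enables linearly scaled stochastic depth across all blocks.
model = timm.create_model('resnet50', output_stride=8, drop_path_rate=0.05)
feats = model.forward_features(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected: torch.Size([1, 2048, 28, 28]) since 224 / 8 = 28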
@@ -529,25 +557,6 @@ class ResNet(nn.Module):
                if hasattr(m, 'zero_init_last_bn'):
                    m.zero_init_last_bn()

-    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
-                    avg_down=False, down_kernel_size=1, **kwargs):
-        downsample = None
-        first_dilation = 1 if dilation in (1, 2) else 2
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample_args = dict(
-                in_channels=self.inplanes, out_channels=planes * block.expansion, kernel_size=down_kernel_size,
-                stride=stride, dilation=dilation, first_dilation=first_dilation, norm_layer=kwargs.get('norm_layer'))
-            downsample = downsample_avg(**downsample_args) if avg_down else downsample_conv(**downsample_args)
-        block_kwargs = dict(
-            cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
-            dilation=dilation, **kwargs)
-        layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)]
-        self.inplanes = planes * block.expansion
-        layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)]
-        return nn.Sequential(*layers)
-
    def get_classifier(self):
        return self.fc
