Only include resnet-rs changes

pull/554/head
Aman Arora 4 years ago
parent 844993267b
commit b117e16128

@ -1,8 +1,6 @@
"""PyTorch ResNet
This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
additional dropout and dynamic global avg/max pool.
ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
Copyright 2020 Ross Wightman
"""
@ -442,7 +440,7 @@ def drop_blocks(drop_block_rate=0.):
def make_blocks(
block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32,
down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., first_conv_stride=1, **kwargs):
down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs):
stages = []
feature_info = []
net_num_blocks = sum(block_repeats)
@ -451,7 +449,7 @@ def make_blocks(
dilation = prev_dilation = 1
for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
stride = first_conv_stride if stage_idx == 0 else 2
stride = 1 if stage_idx == 0 else 2
if net_stride >= output_stride:
dilation *= stride
stride = 1
@ -558,12 +556,12 @@ class ResNet(nn.Module):
cardinality=1, base_width=64, stem_width=64, stem_type='',
output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None, skip_stem_max_pool=False):
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None, replace_stem_max_pool=False):
block_args = block_args or dict()
assert output_stride in (8, 16, 32)
self.num_classes = num_classes
self.drop_rate = drop_rate
self.skip_stem_max_pool = skip_stem_max_pool
self.replace_stem_max_pool = replace_stem_max_pool
super(ResNet, self).__init__()
# Stem
@ -588,8 +586,7 @@ class ResNet(nn.Module):
self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
# Stem Pooling
if not self.skip_stem_max_pool:
first_conv_stride = 1
if not self.replace_stem_max_pool:
if aa_layer is not None:
self.maxpool = nn.Sequential(*[
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
@ -597,8 +594,11 @@ class ResNet(nn.Module):
else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
else:
self.maxpool = nn.Identity()
first_conv_stride = 2
self.maxpool = nn.Sequential(*[
nn.Conv2d(inplanes, inplanes, 3, stride=2, padding=1),
nn.BatchNorm2d(inplanes),
nn.ReLU()
])
# Feature Blocks
channels = [64, 128, 256, 512]
@ -606,7 +606,7 @@ class ResNet(nn.Module):
block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, first_conv_stride=first_conv_stride, **block_args)
drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
for stage in stage_modules:
self.add_module(*stage) # layer1, layer2, etc
self.feature_info.extend(stage_feature_info)
@ -1078,7 +1078,7 @@ def ecaresnet50d(pretrained=False, **kwargs):
@register_model
def resnetrs50(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('resnetrs50', pretrained, **model_args)
@ -1086,7 +1086,7 @@ def resnetrs50(pretrained=False, **kwargs):
@register_model
def resnetrs101(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('resnetrs101', pretrained, **model_args)
@ -1094,7 +1094,7 @@ def resnetrs101(pretrained=False, **kwargs):
@register_model
def resnetrs152(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('resnetrs152', pretrained, **model_args)
@ -1102,7 +1102,7 @@ def resnetrs152(pretrained=False, **kwargs):
@register_model
def resnetrs200(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('resnetrs200', pretrained, **model_args)
@ -1110,7 +1110,7 @@ def resnetrs200(pretrained=False, **kwargs):
@register_model
def resnetrs270(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('resnetrs270', pretrained, **model_args)
@ -1119,7 +1119,7 @@ def resnetrs270(pretrained=False, **kwargs):
@register_model
def resnetrs350(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('resnetrs350', pretrained, **model_args)
@ -1127,7 +1127,7 @@ def resnetrs350(pretrained=False, **kwargs):
@register_model
def resnetrs420(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', skip_stem_max_pool=True,
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('resnetrs420', pretrained, **model_args)

Loading…
Cancel
Save