From 844993267b6b59abe706f65d9a1a1259290115fa Mon Sep 17 00:00:00 2001 From: Aman Arora Date: Sat, 10 Apr 2021 05:28:45 -0400 Subject: [PATCH] Add ResNet-RS models --- timm/models/resnet.py | 99 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 89 insertions(+), 10 deletions(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 656e3a51..f012f797 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -233,7 +233,23 @@ default_cfgs = { interpolation='bicubic'), 'resnetblur50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth', - interpolation='bicubic') + interpolation='bicubic'), + + # ResNet-RS models + 'resnetrs50': _cfg( + interpolation='bicubic'), + 'resnetrs101': _cfg( + interpolation='bicubic'), + 'resnetrs152': _cfg( + interpolation='bicubic'), + 'resnetrs200': _cfg( + interpolation='bicubic'), + 'resnetrs270': _cfg( + interpolation='bicubic'), + 'resnetrs350': _cfg( + interpolation='bicubic'), + 'resnetrs420': _cfg( + interpolation='bicubic'), } @@ -426,7 +442,7 @@ def drop_blocks(drop_block_rate=0.): def make_blocks( block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32, - down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs): + down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., first_conv_stride=1, **kwargs): stages = [] feature_info = [] net_num_blocks = sum(block_repeats) @@ -435,7 +451,7 @@ def make_blocks( dilation = prev_dilation = 1 for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))): stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it - stride = 1 if stage_idx == 0 else 2 + stride = first_conv_stride if stage_idx == 0 else 2 if net_stride >= output_stride: dilation *= stride stride = 1 @@ -542,11 +558,12 @@ class ResNet(nn.Module): cardinality=1, base_width=64, stem_width=64, stem_type='', output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0., - drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): + drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None, skip_stem_max_pool=False): block_args = block_args or dict() assert output_stride in (8, 16, 32) self.num_classes = num_classes self.drop_rate = drop_rate + self.skip_stem_max_pool = skip_stem_max_pool super(ResNet, self).__init__() # Stem @@ -571,12 +588,17 @@ class ResNet(nn.Module): self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')] # Stem Pooling - if aa_layer is not None: - self.maxpool = nn.Sequential(*[ - nn.MaxPool2d(kernel_size=3, stride=1, padding=1), - aa_layer(channels=inplanes, stride=2)]) + if not self.skip_stem_max_pool: + first_conv_stride = 1 + if aa_layer is not None: + self.maxpool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + aa_layer(channels=inplanes, stride=2)]) + else: + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.maxpool = nn.Identity() + first_conv_stride = 2 # Feature Blocks channels = [64, 128, 256, 512] @@ -584,7 +606,7 @@ class ResNet(nn.Module): block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width, output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down, down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, - drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args) + drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, first_conv_stride=first_conv_stride, **block_args) for stage in stage_modules: self.add_module(*stage) # layer1, layer2, etc self.feature_info.extend(stage_feature_info) @@ -1053,6 +1075,63 @@ def ecaresnet50d(pretrained=False, **kwargs): return _create_resnet('ecaresnet50d', pretrained, **model_args) +@register_model +def resnetrs50(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('resnetrs50', pretrained, **model_args) + + +@register_model +def resnetrs101(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('resnetrs101', pretrained, **model_args) + + +@register_model +def resnetrs152(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('resnetrs152', pretrained, **model_args) + + +@register_model +def resnetrs200(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', skip_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('resnetrs200', pretrained, **model_args) + + +@register_model +def resnetrs270(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', skip_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('resnetrs270', pretrained, **model_args) + + + +@register_model +def resnetrs350(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', skip_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('resnetrs350', pretrained, **model_args) + + +@register_model +def resnetrs420(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', skip_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('resnetrs420', pretrained, **model_args) + + @register_model def ecaresnet50d_pruned(pretrained=False, **kwargs): """Constructs a ResNet-50-D model pruned with eca.