From 560eae38f509f60f3effb2257bcc598dbcb67185 Mon Sep 17 00:00:00 2001 From: Aman Arora <41290559+amaarora@users.noreply.github.com> Date: Wed, 5 May 2021 03:59:44 +1000 Subject: [PATCH] [WIP] Add ResNet-RS models (#554) * Add ResNet-RS models * Only include resnet-rs changes * remove whitespace diff * EOF newline * Update time * increase time * Add first conv * Try running only resnetv2_101x1_bitm on Linux runner * Add to exclude filter * Run test_model_forward_features for all * Add to exclude ftrs * back to defaults * only run test_forward_features * run all tests * Run all tests * Add bigger resnetrs to model filters to fix Github CLI * Remove resnetv2_101x1_bitm from exclude feat features * Remove hardcoded values * Make sure reduction ratio in resnetrs is 0.25 * There is no bias in replaced maxpool so remove it --- tests/test_models.py | 5 ++- timm/models/resnet.py | 99 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 93 insertions(+), 11 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 1dee97e7..3da0f872 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -22,8 +22,9 @@ NUM_NON_STD = len(NON_STD_FILTERS) if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models EXCLUDE_FILTERS = [ - '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', - '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*'] + NON_STD_FILTERS + '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', + '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', + '*resnetrs200*', '*resnetrs270*', '*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS else: EXCLUDE_FILTERS = NON_STD_FILTERS diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 2b38b963..377d2d97 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -236,7 +236,23 @@ default_cfgs = { interpolation='bicubic'), 'resnetblur50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth', - interpolation='bicubic') + interpolation='bicubic'), + + # ResNet-RS models + 'resnetrs50': _cfg( + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs101': _cfg( + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs152': _cfg( + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs200': _cfg( + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs270': _cfg( + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs350': _cfg( + interpolation='bicubic', first_conv='conv1.0'), + 'resnetrs420': _cfg( + interpolation='bicubic', first_conv='conv1.0'), } @@ -318,7 +334,7 @@ class Bottleneck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None, **kwargs): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) @@ -341,7 +357,7 @@ class Bottleneck(nn.Module): self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.bn3 = norm_layer(outplanes) - self.se = create_attn(attn_layer, outplanes) + self.se = create_attn(attn_layer, outplanes, **kwargs) self.act3 = act_layer(inplace=True) self.downsample = downsample @@ -545,11 +561,12 @@ class ResNet(nn.Module): cardinality=1, base_width=64, stem_width=64, stem_type='', output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0., - drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): + drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None, replace_stem_max_pool=False): block_args = block_args or dict() assert output_stride in (8, 16, 32) self.num_classes = num_classes self.drop_rate = drop_rate + self.replace_stem_max_pool = replace_stem_max_pool super(ResNet, self).__init__() # Stem @@ -574,12 +591,19 @@ class ResNet(nn.Module): self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')] # Stem Pooling - if aa_layer is not None: - self.maxpool = nn.Sequential(*[ - nn.MaxPool2d(kernel_size=3, stride=1, padding=1), - aa_layer(channels=inplanes, stride=2)]) + if not self.replace_stem_max_pool: + if aa_layer is not None: + self.maxpool = nn.Sequential(*[ + nn.MaxPool2d(kernel_size=3, stride=1, padding=1), + aa_layer(channels=inplanes, stride=2)]) + else: + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.maxpool = nn.Sequential(*[ + nn.Conv2d(inplanes, inplanes, 3, stride=2, padding=1, bias=False), + norm_layer(inplanes), + act_layer(inplace=True) + ]) # Feature Blocks channels = [64, 128, 256, 512] @@ -1065,6 +1089,63 @@ def ecaresnet50d(pretrained=False, **kwargs): return _create_resnet('ecaresnet50d', pretrained, **model_args) +@register_model +def resnetrs50(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs) + return _create_resnet('resnetrs50', pretrained, **model_args) + + +@register_model +def resnetrs101(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs) + return _create_resnet('resnetrs101', pretrained, **model_args) + + +@register_model +def resnetrs152(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs) + return _create_resnet('resnetrs152', pretrained, **model_args) + + +@register_model +def resnetrs200(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs) + return _create_resnet('resnetrs200', pretrained, **model_args) + + +@register_model +def resnetrs270(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs) + return _create_resnet('resnetrs270', pretrained, **model_args) + + + +@register_model +def resnetrs350(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs) + return _create_resnet('resnetrs350', pretrained, **model_args) + + +@register_model +def resnetrs420(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True, + avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs) + return _create_resnet('resnetrs420', pretrained, **model_args) + + @register_model def ecaresnet50d_pruned(pretrained=False, **kwargs): """Constructs a ResNet-50-D model pruned with eca.