Add AvgPool2d anti-aliasing support to ResNet arch (as per OpenAI CLIP models), add a few blur aa models as well

pull/1094/head
Ross Wightman 3 years ago
parent f0f9eccda8
commit 1aa617cb3b

@ -251,6 +251,21 @@ default_cfgs = {
'resnetblur50': _cfg( 'resnetblur50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth',
interpolation='bicubic'), interpolation='bicubic'),
'resnetblur50d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
'resnetblur101d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
'resnetaa50d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
'resnetaa101d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
'seresnetaa50d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
# ResNet-RS models # ResNet-RS models
'resnetrs50': _cfg( 'resnetrs50': _cfg(
@ -289,6 +304,12 @@ def get_padding(kernel_size, stride, dilation=1):
return padding return padding
def create_aa(aa_layer, channels, stride=2, enable=True):
if not aa_layer or not enable:
return None
return aa_layer(stride) if issubclass(aa_layer, nn.AvgPool2d) else aa_layer(channels=channels, stride=stride)
class BasicBlock(nn.Module): class BasicBlock(nn.Module):
expansion = 1 expansion = 1
@ -309,7 +330,7 @@ class BasicBlock(nn.Module):
dilation=first_dilation, bias=False) dilation=first_dilation, bias=False)
self.bn1 = norm_layer(first_planes) self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True) self.act1 = act_layer(inplace=True)
self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None self.aa = create_aa(aa_layer, channels=first_planes, stride=stride, enable=use_aa)
self.conv2 = nn.Conv2d( self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
@ -380,7 +401,7 @@ class Bottleneck(nn.Module):
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(width) self.bn2 = norm_layer(width)
self.act2 = act_layer(inplace=True) self.act2 = act_layer(inplace=True)
self.aa = aa_layer(channels=width, stride=stride) if use_aa else None self.aa = create_aa(aa_layer, channels=width, stride=stride, enable=use_aa)
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes) self.bn3 = norm_layer(outplanes)
@ -617,19 +638,22 @@ class ResNet(nn.Module):
self.act1 = act_layer(inplace=True) self.act1 = act_layer(inplace=True)
self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')] self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
# Stem Pooling # Stem pooling. The name 'maxpool' remains for weight compatibility.
if replace_stem_pool: if replace_stem_pool:
self.maxpool = nn.Sequential(*filter(None, [ self.maxpool = nn.Sequential(*filter(None, [
nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False), nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False),
aa_layer(channels=inplanes, stride=2) if aa_layer else None, create_aa(aa_layer, channels=inplanes, stride=2),
norm_layer(inplanes), norm_layer(inplanes),
act_layer(inplace=True) act_layer(inplace=True)
])) ]))
else: else:
if aa_layer is not None: if aa_layer is not None:
self.maxpool = nn.Sequential(*[ if issubclass(aa_layer, nn.AvgPool2d):
nn.MaxPool2d(kernel_size=3, stride=1, padding=1), self.maxpool = aa_layer(2)
aa_layer(channels=inplanes, stride=2)]) else:
self.maxpool = nn.Sequential(*[
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
aa_layer(channels=inplanes, stride=2)])
else: else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@ -1342,6 +1366,56 @@ def resnetblur50(pretrained=False, **kwargs):
return _create_resnet('resnetblur50', pretrained, **model_args) return _create_resnet('resnetblur50', pretrained, **model_args)
@register_model
def resnetblur50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with blur anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetblur50d', pretrained, **model_args)
@register_model
def resnetblur101d(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model with blur anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetblur101d', pretrained, **model_args)
@register_model
def resnetaa50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetaa50d', pretrained, **model_args)
@register_model
def resnetaa101d(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetaa101d', pretrained, **model_args)
@register_model
def seresnetaa50d(pretrained=False, **kwargs):
"""Constructs a SE=ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnetaa50d', pretrained, **model_args)
@register_model @register_model
def seresnet18(pretrained=False, **kwargs): def seresnet18(pretrained=False, **kwargs):
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs) model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs)

Loading…
Cancel
Save