Post ResNet-RS merge cleanup. Add weight urls, adjust train/test/crop pct.

pull/609/head
Ross Wightman 3 years ago
parent 560eae38f5
commit 67d0665b46

@ -15,16 +15,16 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
torch._C._jit_set_profiling_mode(False)
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*','coat_*']
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*']
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm',
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
'*resnetrs200*', '*resnetrs270*', '*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS
'*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS
else:
EXCLUDE_FILTERS = NON_STD_FILTERS

@ -7,6 +7,7 @@ ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered ste
Copyright 2020 Ross Wightman
"""
import math
from functools import partial
import torch
import torch.nn as nn
@ -14,7 +15,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, create_attn, create_classifier
from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, create_attn, get_attn, create_classifier
from .registry import register_model
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
@ -240,18 +241,32 @@ default_cfgs = {
# ResNet-RS models
'resnetrs50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50-7c9728e2.pth',
input_size=(3, 160, 160), pool_size=(4, 4), crop_pct=0.91, test_input_size=(3, 224, 224),
interpolation='bicubic', first_conv='conv1.0'),
'resnetrs101': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101-3e4bb55c.pth',
input_size=(3, 192, 192), pool_size=(6, 6), crop_pct=0.94, test_input_size=(3, 288, 288),
interpolation='bicubic', first_conv='conv1.0'),
'resnetrs152': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152-b1efe56d.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
interpolation='bicubic', first_conv='conv1.0'),
'resnetrs200': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200-b455b791.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
interpolation='bicubic', first_conv='conv1.0'),
'resnetrs270': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270-cafcfbc7.pth',
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
interpolation='bicubic', first_conv='conv1.0'),
'resnetrs350': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350-06d9bfac.pth',
input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, test_input_size=(3, 384, 384),
interpolation='bicubic', first_conv='conv1.0'),
'resnetrs420': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420-d26764a5.pth',
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, test_input_size=(3, 416, 416),
interpolation='bicubic', first_conv='conv1.0'),
}
@ -334,7 +349,7 @@ class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None, **kwargs):
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(Bottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality)
@ -357,7 +372,7 @@ class Bottleneck(nn.Module):
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.se = create_attn(attn_layer, outplanes, **kwargs)
self.se = create_attn(attn_layer, outplanes)
self.act3 = act_layer(inplace=True)
self.downsample = downsample
@ -558,15 +573,14 @@ class ResNet(nn.Module):
"""
def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64, stem_width=64, stem_type='',
cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False,
output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None, replace_stem_max_pool=False):
drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
block_args = block_args or dict()
assert output_stride in (8, 16, 32)
self.num_classes = num_classes
self.drop_rate = drop_rate
self.replace_stem_max_pool = replace_stem_max_pool
super(ResNet, self).__init__()
# Stem
@ -591,19 +605,20 @@ class ResNet(nn.Module):
self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
# Stem Pooling
if not self.replace_stem_max_pool:
if replace_stem_pool:
self.maxpool = nn.Sequential(*filter(None, [
nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False),
aa_layer(channels=inplanes, stride=2) if aa_layer else None,
norm_layer(inplanes),
act_layer(inplace=True)
]))
else:
if aa_layer is not None:
self.maxpool = nn.Sequential(*[
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
aa_layer(channels=inplanes, stride=2)])
else:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
else:
self.maxpool = nn.Sequential(*[
nn.Conv2d(inplanes, inplanes, 3, stride=2, padding=1, bias=False),
norm_layer(inplanes),
act_layer(inplace=True)
])
# Feature Blocks
channels = [64, 128, 256, 512]
@ -1091,58 +1106,93 @@ def ecaresnet50d(pretrained=False, **kwargs):
@register_model
def resnetrs50(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-50 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs)
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs50', pretrained, **model_args)
@register_model
def resnetrs101(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-101 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs)
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs101', pretrained, **model_args)
@register_model
def resnetrs152(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-152 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs)
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs152', pretrained, **model_args)
@register_model
def resnetrs200(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-200 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs)
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs200', pretrained, **model_args)
@register_model
def resnetrs270(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-270 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs)
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs270', pretrained, **model_args)
@register_model
def resnetrs350(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-350 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs)
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs350', pretrained, **model_args)
@register_model
def resnetrs420(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-420 model
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_max_pool=True,
avg_down=True, block_args=dict(attn_layer='se', reduction_ratio=0.25), **kwargs)
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs420', pretrained, **model_args)

Loading…
Cancel
Save