Add two new SE-ResNeXt101-D 32x8d weights, one anti-aliased and one not. Reshuffle default_cfgs vs model entrypoints for resnet.py so they are better aligned.

pull/1230/head
Ross Wightman 2 years ago
parent fbf597049c
commit 7629d8264d

@ -148,6 +148,49 @@ default_cfgs = {
'swsl_resnext101_32x16d': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'),
# Efficient Channel Attention ResNets
'ecaresnet26t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=0.95, test_input_size=(3, 320, 320)),
'ecaresnetlight': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth',
interpolation='bicubic'),
'ecaresnet50d': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet50d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet50t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=0.95, test_input_size=(3, 320, 320)),
'ecaresnet101d': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth',
interpolation='bicubic', first_conv='conv1.0'),
'ecaresnet101d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet200d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'ecaresnet269d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10),
crop_pct=1.0, test_input_size=(3, 352, 352)),
# Efficient Channel Attention ResNeXts
'ecaresnext26t_32x4d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
'ecaresnext50t_32x4d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
# Squeeze-Excitation ResNets, to eventually replace the models in senet.py
'seresnet18': _cfg(
url='',
@ -180,7 +223,6 @@ default_cfgs = {
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
# Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
'seresnext26d_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth',
@ -199,55 +241,16 @@ default_cfgs = {
'seresnext101_32x8d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101_32x8d_ah-e6bc4c0a.pth',
interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0),
'seresnext101d_32x8d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101d_32x8d_ah-191d7b94.pth',
interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0),
'senet154': _cfg(
url='',
interpolation='bicubic',
first_conv='conv1.0'),
# Efficient Channel Attention ResNets
'ecaresnet26t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=0.95, test_input_size=(3, 320, 320)),
'ecaresnetlight': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth',
interpolation='bicubic'),
'ecaresnet50d': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet50d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet50t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=0.95, test_input_size=(3, 320, 320)),
'ecaresnet101d': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth',
interpolation='bicubic', first_conv='conv1.0'),
'ecaresnet101d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet200d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'ecaresnet269d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10),
crop_pct=1.0, test_input_size=(3, 352, 352)),
# Efficient Channel Attention ResNeXts
'ecaresnext26t_32x4d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
'ecaresnext50t_32x4d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
# ResNets with anti-aliasing blur pool
# ResNets with anti-aliasing / blur pool
'resnetblur18': _cfg(
interpolation='bicubic'),
'resnetblur50': _cfg(
@ -268,6 +271,9 @@ default_cfgs = {
'seresnetaa50d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
'seresnextaa101d_32x8d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnextaa101d_32x8d_ah-83c8ae12.pth',
interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0),
# ResNet-RS models
'resnetrs50': _cfg(
@ -1157,98 +1163,6 @@ def ecaresnet50d(pretrained=False, **kwargs):
return _create_resnet('ecaresnet50d', pretrained, **model_args)
@register_model
def resnetrs50(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-50 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs50', pretrained, **model_args)
@register_model
def resnetrs101(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-101 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs101', pretrained, **model_args)
@register_model
def resnetrs152(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-152 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs152', pretrained, **model_args)
@register_model
def resnetrs200(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-200 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs200', pretrained, **model_args)
@register_model
def resnetrs270(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-270 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs270', pretrained, **model_args)
@register_model
def resnetrs350(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-350 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs350', pretrained, **model_args)
@register_model
def resnetrs420(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-420 model
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs420', pretrained, **model_args)
@register_model
def ecaresnet50d_pruned(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model pruned with eca.
@ -1346,72 +1260,6 @@ def ecaresnext50t_32x4d(pretrained=False, **kwargs):
return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args)
@register_model
def resnetblur18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model with blur anti-aliasing
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs)
return _create_resnet('resnetblur18', pretrained, **model_args)
@register_model
def resnetblur50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with blur anti-aliasing
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs)
return _create_resnet('resnetblur50', pretrained, **model_args)
@register_model
def resnetblur50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with blur anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetblur50d', pretrained, **model_args)
@register_model
def resnetblur101d(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model with blur anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetblur101d', pretrained, **model_args)
@register_model
def resnetaa50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetaa50d', pretrained, **model_args)
@register_model
def resnetaa101d(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetaa101d', pretrained, **model_args)
@register_model
def seresnetaa50d(pretrained=False, **kwargs):
"""Constructs a SE=ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnetaa50d', pretrained, **model_args)
@register_model
def seresnet18(pretrained=False, **kwargs):
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs)
@ -1535,9 +1383,187 @@ def seresnext101_32x8d(pretrained=False, **kwargs):
return _create_resnet('seresnext101_32x8d', pretrained, **model_args)
@register_model
def seresnext101d_32x8d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext101d_32x8d', pretrained, **model_args)
@register_model
def senet154(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('senet154', pretrained, **model_args)
@register_model
def resnetblur18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model with blur anti-aliasing
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs)
return _create_resnet('resnetblur18', pretrained, **model_args)
@register_model
def resnetblur50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with blur anti-aliasing
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs)
return _create_resnet('resnetblur50', pretrained, **model_args)
@register_model
def resnetblur50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with blur anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetblur50d', pretrained, **model_args)
@register_model
def resnetblur101d(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model with blur anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetblur101d', pretrained, **model_args)
@register_model
def resnetaa50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetaa50d', pretrained, **model_args)
@register_model
def resnetaa101d(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetaa101d', pretrained, **model_args)
@register_model
def seresnetaa50d(pretrained=False, **kwargs):
"""Constructs a SE=ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnetaa50d', pretrained, **model_args)
@register_model
def seresnextaa101d_32x8d(pretrained=False, **kwargs):
"""Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnextaa101d_32x8d', pretrained, **model_args)
@register_model
def resnetrs50(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-50 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs50', pretrained, **model_args)
@register_model
def resnetrs101(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-101 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs101', pretrained, **model_args)
@register_model
def resnetrs152(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-152 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs152', pretrained, **model_args)
@register_model
def resnetrs200(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-200 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs200', pretrained, **model_args)
@register_model
def resnetrs270(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-270 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs270', pretrained, **model_args)
@register_model
def resnetrs350(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-350 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs350', pretrained, **model_args)
@register_model
def resnetrs420(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-420 model
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs420', pretrained, **model_args)

Loading…
Cancel
Save