@@ -8,7 +8,9 @@ Additionally, supports non pre-activation bottleneck for use as a backbone for V
 extra padding support to allow porting of official Hybrid ResNet pretrained weights from
 https://github.com/google-research/vision_transformer

-Thanks to the Google team for the above two repositories and associated papers.
+Thanks to the Google team for the above two repositories and associated papers:
+* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370
+* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929

 Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
 """
@@ -86,19 +88,19 @@ default_cfgs = {
         num_classes=21843),

-    # trained on imagenet-1k
-    'resnetv2_50x1_bits': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-S-R50x1-ILSVRC2012.npz'),
-    'resnetv2_50x3_bits': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-S-R50x3-ILSVRC2012.npz'),
-    'resnetv2_101x1_bits': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-S-R101x3-ILSVRC2012.npz'),
-    'resnetv2_101x3_bits': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-S-R101x3-ILSVRC2012.npz'),
-    'resnetv2_152x2_bits': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-S-R152x2-ILSVRC2012.npz'),
-    'resnetv2_152x4_bits': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-S-R152x4-ILSVRC2012.npz'),
+    # trained on imagenet-1k, NOTE not an overly interesting set of weights, leaving disabled for now
+    # 'resnetv2_50x1_bits': _cfg(
+    #     url='https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz'),
+    # 'resnetv2_50x3_bits': _cfg(
+    #     url='https://storage.googleapis.com/bit_models/BiT-S-R50x3.npz'),
+    # 'resnetv2_101x1_bits': _cfg(
+    #     url='https://storage.googleapis.com/bit_models/BiT-S-R101x1.npz'),
+    # 'resnetv2_101x3_bits': _cfg(
+    #     url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
+    # 'resnetv2_152x2_bits': _cfg(
+    #     url='https://storage.googleapis.com/bit_models/BiT-S-R152x2.npz'),
+    # 'resnetv2_152x4_bits': _cfg(
+    #     url='https://storage.googleapis.com/bit_models/BiT-S-R152x4.npz'),
 }
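For anyone who still wants the BiT-S weights, the architectures are unchanged; only the registry entries above are disabled. A minimal sketch, assuming the BiT-S .npz files remain hosted at the URLs in the commented block and have been downloaded locally (the filename below is illustrative):

```python
import timm

# The 'bits' and 'bitm' variants share the same architecture; only the released
# weights differ, so build the matching 'bitm' graph without pretrained weights.
model = timm.create_model('resnetv2_50x1_bitm', pretrained=False)

# ResNetV2 carries a custom npz loader (this is what pretrained_custom_load=True
# in _create_resnetv2 routes to); point it at a manually downloaded checkpoint.
model.load_pretrained('./BiT-S-R50x1.npz')
```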
@@ -358,8 +360,8 @@ class ResNetV2(nn.Module):
         self.feature_info = []
         stem_chs = make_div(stem_chs * wf)
         self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
-        if not preact:
-            self.feature_info.append(dict(num_chs=stem_chs, reduction=4, module='stem'))
+        # NOTE: no usable reduction 2 stem feature if preact, so module is left as '' there
+        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='' if preact else 'stem.norm'))

         prev_chs = stem_chs
         curr_stride = 4
@@ -372,21 +374,19 @@ class ResNetV2(nn.Module):
             if curr_stride >= output_stride:
                 dilation *= stride
                 stride = 1
-            if preact:
-                self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}.norm1')]
             stage = ResNetStage(
                 prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
                 act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_fn=block_fn)
             prev_chs = out_chs
             curr_stride *= stride
-            if not preact:
-                self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
+            feat_name = f'stages.{stage_idx}'
+            if preact:
+                feat_name = f'stages.{stage_idx + 1}.blocks.0.norm1' if (stage_idx + 1) != len(channels) else 'norm'
+            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=feat_name)]
             self.stages.add_module(str(stage_idx), stage)

         self.num_features = prev_chs
         self.norm = norm_layer(self.num_features) if preact else nn.Identity()
-        if preact:
-            self.feature_info += [dict(num_chs=self.num_features, reduction=curr_stride, module=f'norm')]
         self.head = ClassifierHead(
             self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
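Taken together with the stem change above, the feature taps for pre-activation models now point at real module names: each stage's output is read at the following stage's first pre-activation norm (`stages.{i+1}.blocks.0.norm1`), and the last stage at the final `norm`. A rough sketch of exercising this through timm's features_only path (not part of the diff; the printed reductions are what I'd expect given out_indices=(1, 2, 3, 4), not verified output):

```python
import torch
import timm

# Pre-activation BiT model; this change enables hook-based feature extraction.
model = timm.create_model('resnetv2_50x1_bitm', pretrained=False, features_only=True)

print(model.feature_info.module_name())  # e.g. [..., 'stages.1.blocks.0.norm1', ..., 'norm']
print(model.feature_info.reduction())    # expecting [4, 8, 16, 32]

feats = model(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # channels grow while spatial resolution halves per stage
```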
@@ -446,9 +446,15 @@ class ResNetV2(nn.Module):
 def _create_resnetv2(variant, pretrained=False, **kwargs):
     # FIXME feature map extraction is not setup properly for pre-activation mode right now
+    preact = kwargs.get('preact', True)
+    feature_cfg = dict(flatten_sequential=True)
+    if preact:
+        feature_cfg['feature_cls'] = 'hook'
+        feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2, 0 level feat for preact
     return build_model_with_cfg(
         ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True,
-        feature_cfg=dict(flatten_sequential=True), **kwargs)
+        feature_cfg=feature_cfg, **kwargs)


 @register_model
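Why 'hook' rather than the default sequential slicing for preact: with pre-activation, the tensor leaving stage N is only normalized inside stage N+1's first block, so a tap has to hook an interior module instead of cutting a flattened nn.Sequential. A toy stand-alone illustration of the idea with plain PyTorch forward hooks (names and shapes are illustrative, not timm internals):

```python
import torch
import torch.nn as nn

class PreActBlock(nn.Module):
    """Toy pre-activation block: norm/act come *before* the conv."""
    def __init__(self, chs):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(chs)
        self.conv = nn.Conv2d(chs, chs, 3, padding=1)

    def forward(self, x):
        return self.conv(torch.relu(self.norm1(x))) + x

feats = {}

def save_feat(name):
    def hook(module, inputs, output):
        feats[name] = output  # stash the tap; return None so output is unchanged
    return hook

net = nn.Sequential(PreActBlock(8), PreActBlock(8))
# Hook the *second* block's norm1 to read the first block's normalized output,
# mirroring the f'stages.{stage_idx + 1}.blocks.0.norm1' taps in this diff.
net[1].norm1.register_forward_hook(save_feat('stage0'))
net(torch.randn(1, 8, 16, 16))
print(feats['stage0'].shape)  # torch.Size([1, 8, 16, 16])
```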
@@ -496,83 +502,85 @@ def resnetv2_152x4_bitm(pretrained=False, **kwargs):
 @register_model
 def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_50x1_bitm', pretrained=pretrained,
+        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
         layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_50x3_bitm', pretrained=pretrained,
+        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
         layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_101x1_bitm', pretrained=pretrained,
+        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
         layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_101x3_bitm', pretrained=pretrained,
+        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
         layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_152x2_bitm', pretrained=pretrained,
+        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
         layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_152x4_bitm', pretrained=pretrained,
+        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
         layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)


-@register_model
-def resnetv2_50x1_bits(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50x1_bits', pretrained=pretrained,
-        layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
-
-
-@register_model
-def resnetv2_50x3_bits(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50x3_bits', pretrained=pretrained,
-        layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
-
-
-@register_model
-def resnetv2_101x1_bits(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_101x1_bits', pretrained=pretrained,
-        layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
-
-
-@register_model
-def resnetv2_101x3_bits(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_101x3_bits', pretrained=pretrained,
-        layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
-
-
-@register_model
-def resnetv2_152x2_bits(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_152x2_bits', pretrained=pretrained,
-        layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
-
-
-@register_model
-def resnetv2_152x4_bits(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_152x4_bits', pretrained=pretrained,
-        layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
+# NOTE the 'S' versions of the model weights aren't as interesting as the original 21k or transfer-to-1k 'M' weights.
+
+# @register_model
+# def resnetv2_50x1_bits(pretrained=False, **kwargs):
+#     return _create_resnetv2(
+#         'resnetv2_50x1_bits', pretrained=pretrained,
+#         layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
+#
+#
+# @register_model
+# def resnetv2_50x3_bits(pretrained=False, **kwargs):
+#     return _create_resnetv2(
+#         'resnetv2_50x3_bits', pretrained=pretrained,
+#         layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
+#
+#
+# @register_model
+# def resnetv2_101x1_bits(pretrained=False, **kwargs):
+#     return _create_resnetv2(
+#         'resnetv2_101x1_bits', pretrained=pretrained,
+#         layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
+#
+#
+# @register_model
+# def resnetv2_101x3_bits(pretrained=False, **kwargs):
+#     return _create_resnetv2(
+#         'resnetv2_101x3_bits', pretrained=pretrained,
+#         layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
+#
+#
+# @register_model
+# def resnetv2_152x2_bits(pretrained=False, **kwargs):
+#     return _create_resnetv2(
+#         'resnetv2_152x2_bits', pretrained=pretrained,
+#         layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
+#
+#
+# @register_model
+# def resnetv2_152x4_bits(pretrained=False, **kwargs):
+#     return _create_resnetv2(
+#         'resnetv2_152x4_bits', pretrained=pretrained,
+#         layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
+#
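The in21k registration fix above is user-visible: each *_in21k entry now resolves its own default_cfg and 21843-class head instead of reusing the 1k 'bitm' config. A quick sanity-check sketch, assuming a timm build that includes this change:

```python
import timm

# Defaults to the ImageNet-21k head size from the variant's own config.
model = timm.create_model('resnetv2_50x1_bitm_in21k', pretrained=False)
print(model.num_classes)  # 21843

# num_classes can still be overridden, e.g. for fine-tuning on a smaller label set.
model = timm.create_model('resnetv2_50x1_bitm_in21k', pretrained=False, num_classes=10)
print(model.num_classes)  # 10
```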