Add 21k weight urls to vision_transformer. Cleanup feature_info for preact ResNetV2 (BiT) models

commit ce69de70d3 (pull/323/head)
parent 231d04e91a
Ross Wightman, 4 years ago

timm/models/resnetv2.py

@@ -8,7 +8,9 @@ Additionally, supports non pre-activation bottleneck for use as a backbone for V
 extra padding support to allow porting of official Hybrid ResNet pretrained weights from
 https://github.com/google-research/vision_transformer
-Thanks to the Google team for the above two repositories and associated papers.
+Thanks to the Google team for the above two repositories and associated papers:
+* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370
+* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929
 Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
 """
@@ -86,19 +88,19 @@ default_cfgs = {
         num_classes=21843),
-    # trained on imagenet-1k
-    'resnetv2_50x1_bits': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-S-R50x1-ILSVRC2012.npz'),
-    'resnetv2_50x3_bits': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-S-R50x3-ILSVRC2012.npz'),
-    'resnetv2_101x1_bits': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-S-R101x3-ILSVRC2012.npz'),
-    'resnetv2_101x3_bits': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-S-R101x3-ILSVRC2012.npz'),
-    'resnetv2_152x2_bits': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-S-R152x2-ILSVRC2012.npz'),
-    'resnetv2_152x4_bits': _cfg(
-        url='https://storage.googleapis.com/bit_models/BiT-S-R152x4-ILSVRC2012.npz'),
+    # trained on imagenet-1k, NOTE not overly interesting set of weights, leaving disabled for now
+    # 'resnetv2_50x1_bits': _cfg(
+    #     url='https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz'),
+    # 'resnetv2_50x3_bits': _cfg(
+    #     url='https://storage.googleapis.com/bit_models/BiT-S-R50x3.npz'),
+    # 'resnetv2_101x1_bits': _cfg(
+    #     url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
+    # 'resnetv2_101x3_bits': _cfg(
+    #     url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
+    # 'resnetv2_152x2_bits': _cfg(
+    #     url='https://storage.googleapis.com/bit_models/BiT-S-R152x2.npz'),
+    # 'resnetv2_152x4_bits': _cfg(
+    #     url='https://storage.googleapis.com/bit_models/BiT-S-R152x4.npz'),
 }
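For context, the BiT-M transfer weights above stay enabled and load from the Google .npz files via pretrained_custom_load. A minimal smoke-test sketch (standard timm usage; assumes the weights download cleanly):

import timm
import torch

# BiT-M R50x1, 21k pretrain transferred to ImageNet-1k (the '_bitm' variant, not the disabled '_bits')
model = timm.create_model('resnetv2_50x1_bitm', pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])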
@@ -358,8 +360,8 @@ class ResNetV2(nn.Module):
         self.feature_info = []
         stem_chs = make_div(stem_chs * wf)
         self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
-        if not preact:
-            self.feature_info.append(dict(num_chs=stem_chs, reduction=4, module='stem'))
+        # NOTE no, reduction 2 feature if preact
+        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='' if preact else 'stem.norm'))
         prev_chs = stem_chs
         curr_stride = 4
@@ -372,21 +374,19 @@ class ResNetV2(nn.Module):
             if curr_stride >= output_stride:
                 dilation *= stride
                 stride = 1
-            if preact:
-                self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}.norm1')]
             stage = ResNetStage(
                 prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
                 act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_fn=block_fn)
             prev_chs = out_chs
             curr_stride *= stride
-            if not preact:
-                self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
+            feat_name = f'stages.{stage_idx}'
+            if preact:
+                feat_name = f'stages.{stage_idx + 1}.blocks.0.norm1' if (stage_idx + 1) != len(channels) else 'norm'
+            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=feat_name)]
             self.stages.add_module(str(stage_idx), stage)

         self.num_features = prev_chs
         self.norm = norm_layer(self.num_features) if preact else nn.Identity()
-        if preact:
-            self.feature_info += [dict(num_chs=self.num_features, reduction=curr_stride, module=f'norm')]

         self.head = ClassifierHead(
             self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
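To make the renaming concrete, here is the feature_info this loop should produce for a pre-activation resnetv2_50x1 (stem_chs=64, channels (256, 512, 1024, 2048), width_factor 1, output_stride 32). A worked example derived from the code above, not output copied from the repo:

# stem: only a reduction-2 entry, with no hookable module pre-activation (module='')
# stages: each feature taps the *next* stage's first norm1 (or the final 'norm'),
# since pre-activation blocks don't normalize their own output
feature_info = [
    dict(num_chs=64,   reduction=2,  module=''),
    dict(num_chs=256,  reduction=4,  module='stages.1.blocks.0.norm1'),
    dict(num_chs=512,  reduction=8,  module='stages.2.blocks.0.norm1'),
    dict(num_chs=1024, reduction=16, module='stages.3.blocks.0.norm1'),
    dict(num_chs=2048, reduction=32, module='norm'),
]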
@@ -446,9 +446,15 @@ class ResNetV2(nn.Module):
 def _create_resnetv2(variant, pretrained=False, **kwargs):
     # FIXME feature map extraction is not setup properly for pre-activation mode right now
+    preact = kwargs.get('preact', True)
+    feature_cfg = dict(flatten_sequential=True)
+    if preact:
+        feature_cfg['feature_cls'] = 'hook'
+        feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2, 0 level feat for preact
     return build_model_with_cfg(
         ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True,
-        feature_cfg=dict(flatten_sequential=True), **kwargs)
+        feature_cfg=feature_cfg, **kwargs)


 @register_model
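With feature_cls='hook', pre-activation models extract features via forward hooks on the norm modules named in feature_info, and out_indices drops the level-0 entry that has no hookable module. A usage sketch (shapes are what the feature_info above implies for a 224x224 input; treat them as expected, not verified output):

import timm
import torch

model = timm.create_model('resnetv2_50x1_bitm', features_only=True)
feats = model(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
# expect 4 levels at strides 4/8/16/32:
# (1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)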
@@ -496,83 +502,85 @@ def resnetv2_152x4_bitm(pretrained=False, **kwargs):
 @register_model
 def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_50x1_bitm', pretrained=pretrained,
+        'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
         layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_50x3_bitm', pretrained=pretrained,
+        'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
         layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_101x1_bitm', pretrained=pretrained,
+        'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
         layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_101x3_bitm', pretrained=pretrained,
+        'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
         layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_152x2_bitm', pretrained=pretrained,
+        'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
         layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)


 @register_model
 def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
     return _create_resnetv2(
-        'resnetv2_152x4_bitm', pretrained=pretrained,
+        'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.get('num_classes', 21843),
         layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)


-@register_model
-def resnetv2_50x1_bits(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50x1_bits', pretrained=pretrained,
-        layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
-
-
-@register_model
-def resnetv2_50x3_bits(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50x3_bits', pretrained=pretrained,
-        layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
-
-
-@register_model
-def resnetv2_101x1_bits(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_101x1_bits', pretrained=pretrained,
-        layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
-
-
-@register_model
-def resnetv2_101x3_bits(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_101x3_bits', pretrained=pretrained,
-        layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
-
-
-@register_model
-def resnetv2_152x2_bits(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_152x2_bits', pretrained=pretrained,
-        layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
-
-
-@register_model
-def resnetv2_152x4_bits(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_152x4_bits', pretrained=pretrained,
-        layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
+# NOTE the 'S' versions of the model weights arent as interesting as original 21k or transfer to 1K M.
+# @register_model
+# def resnetv2_50x1_bits(pretrained=False, **kwargs):
+#     return _create_resnetv2(
+#         'resnetv2_50x1_bits', pretrained=pretrained,
+#         layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
+#
+#
+# @register_model
+# def resnetv2_50x3_bits(pretrained=False, **kwargs):
+#     return _create_resnetv2(
+#         'resnetv2_50x3_bits', pretrained=pretrained,
+#         layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
+#
+#
+# @register_model
+# def resnetv2_101x1_bits(pretrained=False, **kwargs):
+#     return _create_resnetv2(
+#         'resnetv2_101x1_bits', pretrained=pretrained,
+#         layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
+#
+#
+# @register_model
+# def resnetv2_101x3_bits(pretrained=False, **kwargs):
+#     return _create_resnetv2(
+#         'resnetv2_101x3_bits', pretrained=pretrained,
+#         layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
+#
+#
+# @register_model
+# def resnetv2_152x2_bits(pretrained=False, **kwargs):
+#     return _create_resnetv2(
+#         'resnetv2_152x2_bits', pretrained=pretrained,
+#         layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
+#
+#
+# @register_model
+# def resnetv2_152x4_bits(pretrained=False, **kwargs):
+#     return _create_resnetv2(
+#         'resnetv2_152x4_bits', pretrained=pretrained,
+#         layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
+#
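The in21k registrations now pass their own variant name, so the 21k default_cfg and weights resolve correctly, and they default to the full 21843-class head. A quick sketch; note num_classes is read with kwargs.get rather than popped, so this example leaves it at the default instead of overriding it:

import timm
import torch

model = timm.create_model('resnetv2_50x1_bitm_in21k', pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 21843])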

timm/models/vision_transformer.py

@@ -79,23 +79,27 @@ default_cfgs = {
     # patch models, imagenet21k (weights ported from official JAX impl)
     'vit_base_patch16_224_in21k': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_base_patch32_224_in21k': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_large_patch16_224_in21k': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_large_patch32_224_in21k': _cfg(
-        url='',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
     'vit_huge_patch14_224_in21k': _cfg(
-        url='',
+        url='',  # FIXME I have weights for this but > 2GB limit for github release binaries
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

     # hybrid models (weights ported from official JAX impl)
+    'vit_base_resnet50_224_in21k': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9),
     'vit_base_resnet50_384': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
         input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),

     # hybrid models (my experiments)
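With the release URLs filled in, the 21k patch models now load pretrained directly. A minimal sketch, reading the preprocessing constants back out of the model's default_cfg:

import timm

model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True).eval()
cfg = model.default_cfg
print(cfg['num_classes'], cfg['mean'], cfg['std'])
# 21843 (0.5, 0.5, 0.5) (0.5, 0.5, 0.5)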
@@ -269,6 +273,7 @@ class VisionTransformer(nn.Module):
         # Representation layer
         if representation_size:
+            self.num_features = representation_size
             self.pre_logits = nn.Sequential(OrderedDict([
                 ('fc', nn.Linear(embed_dim, representation_size)),
                 ('act', nn.Tanh())
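Setting self.num_features to representation_size keeps anything sized from num_features (e.g. the classifier head) consistent with the pre_logits output. A standalone sketch of the pre_logits block using the ViT-B values wired up below (embed_dim=768, representation_size=768):

from collections import OrderedDict

import torch
import torch.nn as nn

embed_dim, representation_size = 768, 768
pre_logits = nn.Sequential(OrderedDict([
    ('fc', nn.Linear(embed_dim, representation_size)),
    ('act', nn.Tanh()),
]))
cls_token = torch.randn(2, embed_dim)  # normed class token for a batch of 2
print(pre_logits(cls_token).shape)  # torch.Size([2, 768])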
@@ -315,12 +320,12 @@ class VisionTransformer(nn.Module):
         for blk in self.blocks:
             x = blk(x)
-        x = self.norm(x)
-        return x[:, 0]
+        x = self.norm(x)[:, 0]
+        x = self.pre_logits(x)
+        return x

     def forward(self, x):
         x = self.forward_features(x)
-        x = self.pre_logits(x)
         x = self.head(x)
         return x
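forward_features now runs the class token through pre_logits, so extracted features include the representation layer when one exists; for models without one this should be a no-op, assuming pre_logits falls back to nn.Identity() as elsewhere in this class. A sketch:

import timm
import torch

model = timm.create_model('vit_base_patch16_224')  # no representation layer
feats = model.forward_features(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 768]) - normed class token (pre_logits is a no-op here)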
@@ -426,22 +431,12 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
     return model


-@register_model
-def vit_large_patch32_384(pretrained=False, **kwargs):
-    model = VisionTransformer(
-        img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model.default_cfg = default_cfgs['vit_large_patch32_384']
-    if pretrained:
-        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
-    return model
-
-
 @register_model
 def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
+    num_classes = kwargs.get('num_classes', 21843)
     model = VisionTransformer(
-        patch_size=16, num_classes=21843, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        patch_size=16, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch16_224_in21k']
     if pretrained:
         load_pretrained(
@@ -451,9 +446,10 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
+    num_classes = kwargs.get('num_classes', 21843)
     model = VisionTransformer(
-        img_size=224, num_classes=21843, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
-        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        img_size=224, num_classes=num_classes, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        qkv_bias=True, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_base_patch32_224_in21k']
     if pretrained:
         load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
@@ -462,9 +458,10 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
+    num_classes = kwargs.get('num_classes', 21843)
     model = VisionTransformer(
-        patch_size=16, num_classes=21843, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        patch_size=16, num_classes=num_classes, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        representation_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch16_224_in21k']
     if pretrained:
         load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
@@ -473,9 +470,10 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
+    num_classes = kwargs.get('num_classes', 21843)
     model = VisionTransformer(
-        img_size=224, num_classes=21843, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        img_size=224, num_classes=num_classes, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, representation_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_large_patch32_224_in21k']
     if pretrained:
         load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
@@ -484,15 +482,31 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
 @register_model
 def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
+    num_classes = kwargs.get('num_classes', 21843)
     model = VisionTransformer(
-        img_size=224, patch_size=14, num_classes=21843, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
-        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        img_size=224, patch_size=14, num_classes=num_classes, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, representation_size=1280, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = default_cfgs['vit_huge_patch14_224_in21k']
     if pretrained:
         load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
     return model


+@register_model
+def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
+    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
+    num_classes = kwargs.get('num_classes', 21843)
+    backbone = ResNetV2(
+        layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='')
+    model = VisionTransformer(
+        img_size=224, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+        hybrid_backbone=backbone, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model.default_cfg = default_cfgs['vit_base_resnet50_224_in21k']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+    return model
+
+
 @register_model
 def vit_base_resnet50_384(pretrained=False, **kwargs):
     # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
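The new hybrid 21k model feeds the transformer with tokens from a 3-stage, non-preact ResNetV2 (StdConv + GroupNorm) instead of raw patches. A smoke-test sketch, assuming the ported weights download cleanly:

import timm
import torch

# patch_embed wraps the ResNetV2 backbone; the classifier is the full 21843-class head
model = timm.create_model('vit_base_resnet50_224_in21k', pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 21843])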

validate.py

@@ -60,7 +60,7 @@ parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD'
                     help='Override std deviation of of dataset')
 parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                     help='Image resize interpolation type (overrides model)')
-parser.add_argument('--num-classes', type=int, default=1000,
+parser.add_argument('--num-classes', type=int, default=None,
                     help='Number classes in dataset')
 parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                     help='path to class to idx mapping file (default: "")')
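Defaulting --num-classes to None lets validation fall back to the model's own class count, which matters now that registered models may default to 21843 rather than 1000. A sketch of the intended resolution (my paraphrase, not the script's exact code):

import timm

args_num_classes = None  # parsed value when --num-classes is omitted
model = timm.create_model('vit_base_patch16_224_in21k', pretrained=False)
num_classes = args_num_classes if args_num_classes is not None else model.num_classes
print(num_classes)  # 21843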
