Add DeiT-III 'medium' model defs and weights

pull/1381/head
Ross Wightman 2 years ago
parent 7cd4204a28
commit ec6a28830f

@ -64,6 +64,8 @@ default_cfgs = {
'deit3_small_patch16_384': _cfg( 'deit3_small_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth', url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
input_size=(3, 384, 384), crop_pct=1.0), input_size=(3, 384, 384), crop_pct=1.0),
'deit3_medium_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_1k.pth'),
'deit3_base_patch16_224': _cfg( 'deit3_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'), url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
'deit3_base_patch16_384': _cfg( 'deit3_base_patch16_384': _cfg(
@ -83,6 +85,9 @@ default_cfgs = {
'deit3_small_patch16_384_in21ft1k': _cfg( 'deit3_small_patch16_384_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth', url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
input_size=(3, 384, 384), crop_pct=1.0), input_size=(3, 384, 384), crop_pct=1.0),
'deit3_medium_patch16_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_21k.pth',
crop_pct=1.0),
'deit3_base_patch16_224_in21ft1k': _cfg( 'deit3_base_patch16_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth', url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
crop_pct=1.0), crop_pct=1.0),
@ -290,6 +295,17 @@ def deit3_small_patch16_384(pretrained=False, **kwargs):
return model return model
@register_model
def deit3_medium_patch16_224(pretrained=False, **kwargs):
""" DeiT-3 medium model @ 224x224 (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_medium_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def deit3_base_patch16_224(pretrained=False, **kwargs): def deit3_base_patch16_224(pretrained=False, **kwargs):
""" DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
@ -367,6 +383,17 @@ def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs):
return model return model
@register_model
def deit3_medium_patch16_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 medium model @ 224x224 (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_medium_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs): def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).

Loading…
Cancel
Save