diff --git a/timm/models/deit.py b/timm/models/deit.py index 8cb36bd6..3205b024 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -64,6 +64,8 @@ default_cfgs = { 'deit3_small_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth', input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_medium_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_1k.pth'), 'deit3_base_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'), 'deit3_base_patch16_384': _cfg( @@ -83,6 +85,9 @@ default_cfgs = { 'deit3_small_patch16_384_in21ft1k': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth', input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_medium_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_21k.pth', + crop_pct=1.0), 'deit3_base_patch16_224_in21ft1k': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth', crop_pct=1.0), @@ -290,6 +295,17 @@ def deit3_small_patch16_384(pretrained=False, **kwargs): return model +@register_model +def deit3_medium_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 medium model @ 224x224 (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_medium_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def deit3_base_patch16_224(pretrained=False, **kwargs): """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). @@ -367,6 +383,17 @@ def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs): return model +@register_model +def deit3_medium_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 medium model @ 224x224 (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_medium_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + @register_model def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs): """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).