From b319eb5b5d8d29d109a1ca33bd0de0a1ac0d329c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 18 Jun 2021 16:16:49 -0700 Subject: [PATCH] Update ViT weights, more details to be added before merge. --- timm/models/vision_transformer.py | 264 ++++++++++++++--------- timm/models/vision_transformer_hybrid.py | 128 +++++------ 2 files changed, 211 insertions(+), 181 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 7dd9137e..b8fc6fa5 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -27,7 +27,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .helpers import build_model_with_cfg, named_apply, adapt_input_conv from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ from .registry import register_model @@ -40,106 +40,116 @@ def _cfg(url='', **kwargs): 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } default_cfgs = { - # FIXME weights coming + # patch models (weights from official Google JAX impl) 'vit_tiny_patch16_224': _cfg( - url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - ), + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_tiny_patch16_384': _cfg( - url='', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0 - ), - 'vit_small_patch16_224': _cfg( - url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - ), + url='https://storage.googleapis.com/vit_models/augreg/' + 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), 'vit_small_patch32_224': _cfg( - url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - ), - 'vit_small_patch16_384': _cfg( - url='', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0 - ), + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_small_patch32_384': _cfg( - url='', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0 - ), - - # patch models (weights ported from official Google JAX impl) - 'vit_base_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - ), + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), + 'vit_small_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_patch32_224': _cfg( - url='', # no official model weights for this combo, only for in21k - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - 'vit_base_patch16_384': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 'vit_base_patch32_384': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), - 'vit_large_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_base_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), 'vit_large_patch32_224': _cfg( url='', # no official model weights for this combo, only for in21k - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - 'vit_large_patch16_384': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + ), 'vit_large_patch32_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch16_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), + 'vit_large_patch16_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0), - # patch models, imagenet21k (weights ported from official Google JAX impl) - 'vit_base_patch16_224_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth', - num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + # patch models, imagenet21k (weights from official Google JAX impl) + 'vit_tiny_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_small_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), 'vit_base_patch32_224_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth', - num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - 'vit_large_patch16_224_in21k': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth', - num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', + num_classes=21843), + 'vit_base_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843), 'vit_large_patch32_224_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', - num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + num_classes=21843), + 'vit_large_patch16_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + num_classes=21843), 'vit_huge_patch14_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', hf_hub='timm/vit_huge_patch14_224_in21k', - num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + num_classes=21843), # deit models (FB weights) 'deit_tiny_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'deit_small_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), + url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'deit_base_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'deit_base_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', - input_size=(3, 384, 384), crop_pct=1.0), + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), 'deit_tiny_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', - classifier=('head', 'head_dist')), + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 'deit_small_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', - classifier=('head', 'head_dist')), + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 'deit_base_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', - classifier=('head', 'head_dist')), + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 'deit_base_distilled_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', - input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')), + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, + classifier=('head', 'head_dist')), # ViT ImageNet-21K-P pretraining by MILL 'vit_base_patch16_224_miil_in21k': _cfg( @@ -530,12 +540,11 @@ def vit_tiny_patch16_224(pretrained=False, **kwargs): @register_model -def vit_small_patch16_224(pretrained=False, **kwargs): - """ ViT-Small (ViT-S/16) - NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper +def vit_tiny_patch16_384(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) @ 384x384. """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) return model @@ -543,28 +552,37 @@ def vit_small_patch16_224(pretrained=False, **kwargs): def vit_small_patch32_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/32) """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_small_patch16_384(pretrained=False, **kwargs): +def vit_small_patch32_384(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/32) at 384x384. + """ + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_patch16_224(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. +def vit_small_patch16_384(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) return model @@ -577,6 +595,26 @@ def vit_base_patch32_224(pretrained=False, **kwargs): return model +@register_model +def vit_base_patch32_384(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_patch16_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). @@ -588,31 +626,31 @@ def vit_base_patch16_384(pretrained=False, **kwargs): @register_model -def vit_base_patch32_384(pretrained=False, **kwargs): - """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. +def vit_large_patch32_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. """ - model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_large_patch16_224(pretrained=False, **kwargs): +def vit_large_patch32_384(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_large_patch32_224(pretrained=False, **kwargs): - """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. +def vit_large_patch16_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -627,23 +665,32 @@ def vit_large_patch16_384(pretrained=False, **kwargs): @register_model -def vit_large_patch32_384(pretrained=False, **kwargs): - """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. +def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_patch16_224_in21k(pretrained=False, **kwargs): - """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). +def vit_small_patch32_224_in21k(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/16) + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -659,13 +706,13 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs): @register_model -def vit_large_patch16_224_in21k(pretrained=False, **kwargs): - """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). +def vit_base_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. """ model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) - model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -680,6 +727,17 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs): return model +@register_model +def vit_large_patch16_224_in21k(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 1bfe6685..a53419a0 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -35,34 +35,51 @@ def _cfg(url='', **kwargs): default_cfgs = { - # hybrid in-1k models (weights ported from official JAX impl where they exist) - 'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), + # hybrid in-1k models (weights from official JAX impl where they exist) + 'vit_tiny_r_s16_p8_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + first_conv='patch_embed.backbone.conv'), 'vit_tiny_r_s16_p8_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), - 'vit_small_r20_s16_p2_224': _cfg(), - 'vit_small_r20_s16_224': _cfg(), - 'vit_small_r26_s32_224': _cfg(), + 'vit_small_r26_s32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz', + ), 'vit_small_r26_s32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_r20_s16_224': _cfg(), 'vit_base_r26_s32_224': _cfg(), 'vit_base_r50_s16_224': _cfg(), 'vit_base_r50_s16_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', input_size=(3, 384, 384), crop_pct=1.0), - 'vit_large_r50_s32_224': _cfg(), - 'vit_large_r50_s32_384': _cfg(), - - # hybrid in-21k models (weights ported from official Google JAX impl where they exist) + 'vit_large_r50_s32_224': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz' + ), + 'vit_large_r50_s32_384': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/' + 'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + input_size=(3, 384, 384), crop_pct=1.0 + ), + + # hybrid in-21k models (weights from official Google JAX impl where they exist) + 'vit_tiny_r_s16_p8_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i1k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'), 'vit_small_r26_s32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz', num_classes=21843, crop_pct=0.9), - 'vit_small_r26_s32_384_in21k': _cfg( - num_classes=21843, input_size=(3, 384, 384), crop_pct=1.0), 'vit_base_r50_s16_224_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', num_classes=21843, crop_pct=0.9), - 'vit_large_r50_s32_224_in21k': _cfg(num_classes=21843, crop_pct=0.9), + 'vit_large_r50_s32_224_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz', + num_classes=21843, crop_pct=0.9), # hybrid models (using timm resnet backbones) 'vit_small_resnet26d_224': _cfg( @@ -163,51 +180,6 @@ def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): return model -@register_model -def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): - """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384. - """ - backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_r_s16_p8_224(pretrained=False, **kwargs): - """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224. - """ - backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_small_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) - - return model - - -@register_model -def vit_small_r20_s16_p2_224(pretrained=False, **kwargs): - """ R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224. - """ - backbone = _resnetv2((2, 4), **kwargs) - model_kwargs = dict(patch_size=2, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_small_r20_s16_p2_224', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_r20_s16_224(pretrained=False, **kwargs): - """ R20+ViT-S/S16 hybrid. - """ - backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_small_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_small_r26_s32_224(pretrained=False, **kwargs): """ R26+ViT-S/S32 hybrid. @@ -230,17 +202,6 @@ def vit_small_r26_s32_384(pretrained=False, **kwargs): return model -@register_model -def vit_base_r20_s16_224(pretrained=False, **kwargs): - """ R20+ViT-B/S16 hybrid. - """ - backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer_hybrid( - 'vit_base_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_base_r26_s32_224(pretrained=False, **kwargs): """ R26+ViT-B/S32 hybrid. @@ -298,24 +259,24 @@ def vit_large_r50_s32_384(pretrained=False, **kwargs): @register_model -def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs): - """ R26+ViT-S/S32 hybrid. +def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k. """ - backbone = _resnetv2((2, 2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer_hybrid( - 'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + 'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @register_model -def vit_small_r26_s32_384_in21k(pretrained=False, **kwargs): - """ R26+ViT-S/S32 hybrid. +def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. ImageNet-21k. """ backbone = _resnetv2((2, 2, 2, 2), **kwargs) model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer_hybrid( - 'vit_small_r26_s32_384_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + 'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -331,6 +292,17 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): return model +@register_model +def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. ImageNet-21k. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_small_resnet26d_224(pretrained=False, **kwargs): """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.