Update ViT weights, more details to be added before merge.

cleanup_xla_model_fixes
Ross Wightman 3 years ago
parent 8257b86550
commit b319eb5b5d
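
Reviewer note: the default_cfgs changes below point the patch models at the AugReg .npz releases and switch the default normalization from IMAGENET_DEFAULT_MEAN/STD to the Inception (0.5, 0.5, 0.5) constants, so eval preprocessing should be derived from each model's default_cfg rather than hard-coded. A minimal usage sketch (not part of this diff, assuming timm's existing create_model / resolve_data_config helpers):

import timm
from timm.data import resolve_data_config, create_transform

# Any of the names wired to new weights below, e.g. the new ViT-Ti entry.
model = timm.create_model('vit_tiny_patch16_224', pretrained=True).eval()

# Pull mean/std, input_size and crop_pct from the model's default_cfg so the
# preprocessing matches the Inception-style normalization these weights expect.
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
print(config['mean'], config['std'], config['input_size'])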

@@ -27,7 +27,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
 from .registry import register_model
@@ -40,106 +40,116 @@ def _cfg(url='', **kwargs):
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
         'first_conv': 'patch_embed.proj', 'classifier': 'head',
         **kwargs
     }


 default_cfgs = {
-    # FIXME weights coming
+    # patch models (weights from official Google JAX impl)
     'vit_tiny_patch16_224': _cfg(
-        url='',
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-    ),
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
     'vit_tiny_patch16_384': _cfg(
-        url='',
-        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0
-    ),
-    'vit_small_patch16_224': _cfg(
-        url='',
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-    ),
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        input_size=(3, 384, 384), crop_pct=1.0),
     'vit_small_patch32_224': _cfg(
-        url='',
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-    ),
-    'vit_small_patch16_384': _cfg(
-        url='',
-        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0
-    ),
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
     'vit_small_patch32_384': _cfg(
-        url='',
-        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0
-    ),
-
-    # patch models (weights ported from official Google JAX impl)
-    'vit_base_patch16_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
-    ),
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_small_patch16_224': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
+    'vit_small_patch16_384': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        input_size=(3, 384, 384), crop_pct=1.0),
     'vit_base_patch32_224': _cfg(
-        url='', # no official model weights for this combo, only for in21k
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
-    'vit_base_patch16_384': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
-        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
     'vit_base_patch32_384': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
-        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
-    'vit_large_patch16_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_base_patch16_224': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
+    'vit_base_patch16_384': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        input_size=(3, 384, 384), crop_pct=1.0),
     'vit_large_patch32_224': _cfg(
         url='', # no official model weights for this combo, only for in21k
-        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
-    'vit_large_patch16_384': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
-        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+        ),
     'vit_large_patch32_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
-        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_large_patch16_224': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
+    'vit_large_patch16_384': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        input_size=(3, 384, 384), crop_pct=1.0),

-    # patch models, imagenet21k (weights ported from official Google JAX impl)
-    'vit_base_patch16_224_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
-        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+    # patch models, imagenet21k (weights from official Google JAX impl)
+    'vit_tiny_patch16_224_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
+        num_classes=21843),
+    'vit_small_patch32_224_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
+        num_classes=21843),
+    'vit_small_patch16_224_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
+        num_classes=21843),
     'vit_base_patch32_224_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
-        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
-    'vit_large_patch16_224_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
-        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+        url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
+        num_classes=21843),
+    'vit_base_patch16_224_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
+        num_classes=21843),
     'vit_large_patch32_224_in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
-        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+        num_classes=21843),
+    'vit_large_patch16_224_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
+        num_classes=21843),
     'vit_huge_patch14_224_in21k': _cfg(
         url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
         hf_hub='timm/vit_huge_patch14_224_in21k',
-        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+        num_classes=21843),

     # deit models (FB weights)
     'deit_tiny_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
+        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'deit_small_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
+        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'deit_base_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
+        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'deit_base_patch16_384': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
-        input_size=(3, 384, 384), crop_pct=1.0),
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
     'deit_tiny_distilled_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
-        classifier=('head', 'head_dist')),
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
     'deit_small_distilled_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
-        classifier=('head', 'head_dist')),
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
     'deit_base_distilled_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
-        classifier=('head', 'head_dist')),
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
     'deit_base_distilled_patch16_384': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
-        input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
+        classifier=('head', 'head_dist')),

     # ViT ImageNet-21K-P pretraining by MILL
     'vit_base_patch16_224_miil_in21k': _cfg(
@@ -530,12 +540,11 @@ def vit_tiny_patch16_224(pretrained=False, **kwargs):
 @register_model
-def vit_small_patch16_224(pretrained=False, **kwargs):
-    """ ViT-Small (ViT-S/16)
-    NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
+def vit_tiny_patch16_384(pretrained=False, **kwargs):
+    """ ViT-Tiny (Vit-Ti/16) @ 384x384.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+    model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -543,28 +552,37 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
 def vit_small_patch32_224(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/32)
     """
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
     return model


 @register_model
-def vit_small_patch16_384(pretrained=False, **kwargs):
+def vit_small_patch32_384(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/32) at 384x384.
+    """
+    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_patch16_224(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
     """
     model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
     return model


 @register_model
-def vit_base_patch16_224(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
-    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+def vit_small_patch16_384(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/16)
+    NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -577,6 +595,26 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_base_patch32_384(pretrained=False, **kwargs):
+    """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+    """
+    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_base_patch16_384(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
@@ -588,31 +626,31 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
 @register_model
-def vit_base_patch32_384(pretrained=False, **kwargs):
-    """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
-    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+def vit_large_patch32_224(pretrained=False, **kwargs):
+    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+    model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
     return model


 @register_model
-def vit_large_patch16_224(pretrained=False, **kwargs):
+def vit_large_patch32_384(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
-    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+    model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
     return model


 @register_model
-def vit_large_patch32_224(pretrained=False, **kwargs):
-    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
+def vit_large_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+    model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -627,23 +665,32 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 @register_model
-def vit_large_patch32_384(pretrained=False, **kwargs):
-    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
-    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
+    """ ViT-Tiny (Vit-Ti/16).
+    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+    model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model


 @register_model
-def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
-    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/16)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/16)
+    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -659,13 +706,13 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
 @register_model
-def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
-    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
+    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
-    model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -680,6 +727,17 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
+    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+    model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).

@@ -35,34 +35,51 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
-    # hybrid in-1k models (weights ported from official JAX impl where they exist)
-    'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
+    # hybrid in-1k models (weights from official JAX impl where they exist)
+    'vit_tiny_r_s16_p8_224': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+        first_conv='patch_embed.backbone.conv'),
     'vit_tiny_r_s16_p8_384': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
-    'vit_small_r20_s16_p2_224': _cfg(),
-    'vit_small_r20_s16_224': _cfg(),
-    'vit_small_r26_s32_224': _cfg(),
+    'vit_small_r26_s32_224': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+        ),
     'vit_small_r26_s32_384': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
         input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_base_r20_s16_224': _cfg(),
     'vit_base_r26_s32_224': _cfg(),
     'vit_base_r50_s16_224': _cfg(),
     'vit_base_r50_s16_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
         input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_large_r50_s32_224': _cfg(),
-    'vit_large_r50_s32_384': _cfg(),
+    'vit_large_r50_s32_224': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'
+        ),
+    'vit_large_r50_s32_384': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/'
+            'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+        input_size=(3, 384, 384), crop_pct=1.0
+        ),

-    # hybrid in-21k models (weights ported from official Google JAX impl where they exist)
+    # hybrid in-21k models (weights from official Google JAX impl where they exist)
+    'vit_tiny_r_s16_p8_224_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i1k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
+        num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'),
     'vit_small_r26_s32_224_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
         num_classes=21843, crop_pct=0.9),
-    'vit_small_r26_s32_384_in21k': _cfg(
-        num_classes=21843, input_size=(3, 384, 384), crop_pct=1.0),
     'vit_base_r50_s16_224_in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
         num_classes=21843, crop_pct=0.9),
-    'vit_large_r50_s32_224_in21k': _cfg(num_classes=21843, crop_pct=0.9),
+    'vit_large_r50_s32_224_in21k': _cfg(
+        url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
+        num_classes=21843, crop_pct=0.9),

     # hybrid models (using timm resnet backbones)
     'vit_small_resnet26d_224': _cfg(
@@ -163,51 +180,6 @@ def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
     return model


-@register_model
-def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
-    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
-    """
-    backbone = _resnetv2(layers=(), **kwargs)
-    model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
-    model = _create_vision_transformer_hybrid(
-        'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
-    return model
-
-
-@register_model
-def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
-    """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
-    """
-    backbone = _resnetv2(layers=(), **kwargs)
-    model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer_hybrid(
-        'vit_small_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
-    return model
-
-
-@register_model
-def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
-    """ R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
-    """
-    backbone = _resnetv2((2, 4), **kwargs)
-    model_kwargs = dict(patch_size=2, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer_hybrid(
-        'vit_small_r20_s16_p2_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
-    return model
-
-
-@register_model
-def vit_small_r20_s16_224(pretrained=False, **kwargs):
-    """ R20+ViT-S/S16 hybrid.
-    """
-    backbone = _resnetv2((2, 2, 2), **kwargs)
-    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer_hybrid(
-        'vit_small_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
-    return model
-
-
 @register_model
 def vit_small_r26_s32_224(pretrained=False, **kwargs):
     """ R26+ViT-S/S32 hybrid.
@@ -230,17 +202,6 @@ def vit_small_r26_s32_384(pretrained=False, **kwargs):
     return model


-@register_model
-def vit_base_r20_s16_224(pretrained=False, **kwargs):
-    """ R20+ViT-B/S16 hybrid.
-    """
-    backbone = _resnetv2((2, 2, 2), **kwargs)
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer_hybrid(
-        'vit_base_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
-    return model
-
-
 @register_model
 def vit_base_r26_s32_224(pretrained=False, **kwargs):
     """ R26+ViT-B/S32 hybrid.
@@ -298,24 +259,24 @@ def vit_large_r50_s32_384(pretrained=False, **kwargs):
 @register_model
-def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
-    """ R26+ViT-S/S32 hybrid.
+def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs):
+    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k.
     """
-    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
-    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
+    backbone = _resnetv2(layers=(), **kwargs)
+    model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer_hybrid(
-        'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+        'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
     return model


 @register_model
-def vit_small_r26_s32_384_in21k(pretrained=False, **kwargs):
-    """ R26+ViT-S/S32 hybrid.
+def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
+    """ R26+ViT-S/S32 hybrid. ImageNet-21k.
     """
     backbone = _resnetv2((2, 2, 2, 2), **kwargs)
     model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer_hybrid(
-        'vit_small_r26_s32_384_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+        'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
     return model
@@ -331,6 +292,17 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
+    """ R50+ViT-L/S32 hybrid. ImageNet-21k.
+    """
+    backbone = _resnetv2((3, 4, 6, 3), **kwargs)
+    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
+    model = _create_vision_transformer_hybrid(
+        'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_small_resnet26d_224(pretrained=False, **kwargs):
     """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
