Update ViT weights, more details to be added before merge.

cleanup_xla_model_fixes
Ross Wightman 4 years ago
parent 8257b86550
commit b319eb5b5d

@ -27,7 +27,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from .registry import register_model
@ -40,106 +40,116 @@ def _cfg(url='', **kwargs):
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
# FIXME weights coming
# patch models (weights from official Google JAX impl)
'vit_tiny_patch16_224': _cfg(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
),
url='https://storage.googleapis.com/vit_models/augreg/'
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_tiny_patch16_384': _cfg(
url='',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0
),
'vit_small_patch16_224': _cfg(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
),
url='https://storage.googleapis.com/vit_models/augreg/'
'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch32_224': _cfg(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
),
'vit_small_patch16_384': _cfg(
url='',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0
),
url='https://storage.googleapis.com/vit_models/augreg/'
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_small_patch32_384': _cfg(
url='',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0
),
# patch models (weights ported from official Google JAX impl)
'vit_base_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
),
url='https://storage.googleapis.com/vit_models/augreg/'
'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_small_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch32_224': _cfg(
url='', # no official model weights for this combo, only for in21k
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_base_patch16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
url='https://storage.googleapis.com/vit_models/augreg/'
'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
'vit_base_patch32_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
'vit_large_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
url='https://storage.googleapis.com/vit_models/augreg/'
'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
'vit_base_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch32_224': _cfg(
url='', # no official model weights for this combo, only for in21k
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_large_patch16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
),
'vit_large_patch32_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_patch16_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
'vit_large_patch16_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
# patch models, imagenet21k (weights ported from official Google JAX impl)
'vit_base_patch16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
# patch models, imagenet21k (weights from official Google JAX impl)
'vit_tiny_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_small_patch32_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_small_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_base_patch32_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_large_patch16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_base_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843),
'vit_large_patch32_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
num_classes=21843),
'vit_large_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
num_classes=21843),
'vit_huge_patch14_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
hf_hub='timm/vit_huge_patch14_224_in21k',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
num_classes=21843),
# deit models (FB weights)
'deit_tiny_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'deit_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'deit_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'deit_base_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
input_size=(3, 384, 384), crop_pct=1.0),
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
'deit_tiny_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
classifier=('head', 'head_dist')),
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
'deit_small_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
classifier=('head', 'head_dist')),
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
'deit_base_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
classifier=('head', 'head_dist')),
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
'deit_base_distilled_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
classifier=('head', 'head_dist')),
# ViT ImageNet-21K-P pretraining by MILL
'vit_base_patch16_224_miil_in21k': _cfg(
@ -530,12 +540,11 @@ def vit_tiny_patch16_224(pretrained=False, **kwargs):
@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
def vit_tiny_patch16_384(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16) @ 384x384.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@ -543,28 +552,37 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
def vit_small_patch32_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/32)
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch16_384(pretrained=False, **kwargs):
def vit_small_patch32_384(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/32) at 384x384.
"""
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
def vit_small_patch16_384(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@ -577,6 +595,26 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
return model
@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
@ -588,31 +626,31 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
def vit_large_patch32_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch16_224(pretrained=False, **kwargs):
def vit_large_patch32_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_large_patch32_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
def vit_large_patch16_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@ -627,23 +665,32 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
@register_model
def vit_large_patch32_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -659,13 +706,13 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -680,6 +727,17 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
return model
@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).

@ -35,34 +35,51 @@ def _cfg(url='', **kwargs):
default_cfgs = {
# hybrid in-1k models (weights ported from official JAX impl where they exist)
'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
# hybrid in-1k models (weights from official JAX impl where they exist)
'vit_tiny_r_s16_p8_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
first_conv='patch_embed.backbone.conv'),
'vit_tiny_r_s16_p8_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
'vit_small_r20_s16_p2_224': _cfg(),
'vit_small_r20_s16_224': _cfg(),
'vit_small_r26_s32_224': _cfg(),
'vit_small_r26_s32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
),
'vit_small_r26_s32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_r20_s16_224': _cfg(),
'vit_base_r26_s32_224': _cfg(),
'vit_base_r50_s16_224': _cfg(),
'vit_base_r50_s16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_r50_s32_224': _cfg(),
'vit_large_r50_s32_384': _cfg(),
# hybrid in-21k models (weights ported from official Google JAX impl where they exist)
'vit_large_r50_s32_224': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'
),
'vit_large_r50_s32_384': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/'
'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0
),
# hybrid in-21k models (weights from official Google JAX impl where they exist)
'vit_tiny_r_s16_p8_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i1k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'),
'vit_small_r26_s32_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9),
'vit_small_r26_s32_384_in21k': _cfg(
num_classes=21843, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_r50_s16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
num_classes=21843, crop_pct=0.9),
'vit_large_r50_s32_224_in21k': _cfg(num_classes=21843, crop_pct=0.9),
'vit_large_r50_s32_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9),
# hybrid models (using timm resnet backbones)
'vit_small_resnet26d_224': _cfg(
@ -163,51 +180,6 @@ def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
return model
@register_model
def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
""" R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
"""
backbone = _resnetv2((2, 4), **kwargs)
model_kwargs = dict(patch_size=2, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r20_s16_p2_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r20_s16_224(pretrained=False, **kwargs):
""" R20+ViT-S/S16 hybrid.
"""
backbone = _resnetv2((2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r26_s32_224(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid.
@ -230,17 +202,6 @@ def vit_small_r26_s32_384(pretrained=False, **kwargs):
return model
@register_model
def vit_base_r20_s16_224(pretrained=False, **kwargs):
""" R20+ViT-B/S16 hybrid.
"""
backbone = _resnetv2((2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_r26_s32_224(pretrained=False, **kwargs):
""" R26+ViT-B/S32 hybrid.
@ -298,24 +259,24 @@ def vit_large_r50_s32_384(pretrained=False, **kwargs):
@register_model
def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid.
def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r26_s32_384_in21k(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid.
def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid. ImageNet-21k.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r26_s32_384_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@ -331,6 +292,17 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
return model
@register_model
def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid. ImageNet-21k.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.

Loading…
Cancel
Save