@@ -5,12 +5,6 @@ A PyTorch implementation of Vision Transformers as described in
The official jax code is released and available at https://github.com/google-research/vision_transformer
Status/TODO:
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
@@ -18,6 +12,9 @@ for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
DeiT model defs and weights from https://github.com/facebookresearch/deit,
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
@@ -50,7 +47,7 @@ default_cfgs = {
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
    ),
-    # patch models (weights ported from official JAX impl)
+    # patch models (weights ported from official Google JAX impl)
    'vit_base_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
@@ -77,7 +74,7 @@ default_cfgs = {
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
-    # patch models, imagenet21k (weights ported from official JAX impl)
+    # patch models, imagenet21k (weights ported from official Google JAX impl)
    'vit_base_patch16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
@@ -94,7 +91,7 @@ default_cfgs = {
        url='',  # FIXME I have weights for this but > 2GB limit for github release binaries
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
-    # hybrid models (weights ported from official JAX impl)
+    # hybrid models (weights ported from official Google JAX impl)
    'vit_base_resnet50_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9),
@@ -107,6 +104,17 @@ default_cfgs = {
    'vit_small_resnet50d_s3_224': _cfg(),
    'vit_base_resnet26d_224': _cfg(),
    'vit_base_resnet50d_224': _cfg(),
    # deit models (FB weights)
    'deit_tiny_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
    'deit_small_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
    'deit_base_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'),
    'deit_base_patch16_384': _cfg(
        url='',  # no weights yet
        input_size=(3, 384, 384)),
}
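# Illustrative sketch only (not part of this module or of timm's transform factory):
# it shows how the per-model cfg values above (input_size, crop_pct, mean, std) are
# typically turned into an eval-time preprocessing pipeline. The helper name and the
# fallback defaults are assumptions for the example, e.g.
# _example_eval_transform(default_cfgs['vit_base_patch16_384']).
def _example_eval_transform(cfg):
    import math
    from torchvision import transforms  # assumes torchvision is available
    img_size = cfg.get('input_size', (3, 224, 224))[-1]
    crop_pct = cfg.get('crop_pct', 0.9)
    scale_size = int(math.floor(img_size / crop_pct))  # resize shorter side, then center crop
    return transforms.Compose([
        transforms.Resize(scale_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg.get('mean', (0.485, 0.456, 0.406)),
                             std=cfg.get('std', (0.229, 0.224, 0.225))),
    ])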
@@ -433,7 +441,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
@register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
    model = VisionTransformer(
        patch_size=16, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
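# Note: the switch from kwargs.get to kwargs.pop above matters because num_classes is
# also passed to VisionTransformer explicitly; pop removes it from kwargs so the
# trailing **kwargs does not supply it a second time (which would raise
# "got multiple values for keyword argument 'num_classes'").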
@@ -446,7 +454,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
@register_model
def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
    model = VisionTransformer(
        img_size=224, num_classes=num_classes, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
@@ -458,7 +466,7 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
    model = VisionTransformer(
        patch_size=16, num_classes=num_classes, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        representation_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
@@ -482,7 +490,7 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
    model = VisionTransformer(
        img_size=224, patch_size=14, num_classes=num_classes, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
        qkv_bias=True, representation_size=1280, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
@@ -495,7 +503,7 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
-    num_classes = kwargs.get('num_classes', 21843)
+    num_classes = kwargs.pop('num_classes', 21843)
    backbone = ResNetV2(
        layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='')
    model = VisionTransformer(
@@ -559,3 +567,51 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs):
        img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
    model.default_cfg = default_cfgs['vit_base_resnet50d_224']
    return model
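# Illustrative sketch (an assumption, not this module's actual hybrid embedding code):
# the hybrid vit_*_resnet* models above run a CNN backbone first, then treat the
# resulting feature map as a grid of patch tokens by projecting each spatial position
# to the transformer embedding dim. The class and argument names below are made up
# for the example.
class _ExampleHybridPatchEmbed(nn.Module):
    def __init__(self, backbone, feature_dim, embed_dim=768):
        super().__init__()
        # backbone is assumed to return a (B, feature_dim, H', W') feature map,
        # e.g. a ResNet created with num_classes=0, global_pool='' as above
        self.backbone = backbone
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1)  # 1x1 projection

    def forward(self, x):
        x = self.backbone(x)                 # (B, feature_dim, H', W')
        x = self.proj(x)                     # (B, embed_dim, H', W')
        return x.flatten(2).transpose(1, 2)  # (B, H'*W', embed_dim) token sequence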
@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['deit_tiny_patch16_224']
    if pretrained:
        load_pretrained(
            # filter_fn unwraps the FB DeiT checkpoints, which store weights under a 'model' key
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
    return model
@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['deit_small_patch16_224']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
    return model
@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['deit_base_patch16_224']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
    return model
@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['deit_base_patch16_384']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model'])
    return model
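# Usage sketch (an illustration, not part of the library): creating one of the DeiT
# models registered above through timm's factory and running a forward pass. Assumes
# timm is installed so that @register_model has populated the model registry; the
# model name matches the default_cfgs entry above.
if __name__ == '__main__':
    import timm
    m = timm.create_model('deit_base_patch16_224', pretrained=False)
    m.eval()
    with torch.no_grad():
        out = m(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 1000])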