@@ -6,7 +6,7 @@ A PyTorch implement of Vision Transformers as described in:
     - https://arxiv.org/abs/2010.11929
 
 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
-    - https://arxiv.org/abs/2106.TODO
+    - https://arxiv.org/abs/2106.10270
 
 The official jax code is released and available at https://github.com/google-research/vision_transformer
@@ -448,9 +448,12 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     model.pos_embed.copy_(pos_embed_w)
     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
-    if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+    if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+        model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+        model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
         mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
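For context on the new guards: `_load_weights` copies converted jax/Flax arrays into the PyTorch module, and both `model.head` and `model.pre_logits` can be `nn.Identity()` (e.g. when the model was built with `num_classes=0` or without `representation_size`), so an unguarded `.copy_()` would fail. A minimal sketch of the same guard pattern, not the timm implementation itself (`copy_head_if_present` is a hypothetical helper):

```python
import torch
import torch.nn as nn

# Sketch of the guard pattern used above: copy Flax 'head/kernel' / 'head/bias'
# arrays into a PyTorch classifier that may have been replaced by nn.Identity().
def copy_head_if_present(head: nn.Module, kernel: torch.Tensor, bias: torch.Tensor) -> bool:
    # Flax Dense kernels are stored (in_features, out_features); nn.Linear.weight
    # is (out_features, in_features), hence the transpose before copying.
    if isinstance(head, nn.Linear) and head.bias.shape[0] == bias.shape[-1]:
        with torch.no_grad():
            head.weight.copy_(kernel.T)
            head.bias.copy_(bias)
        return True
    return False

# A model built with num_classes=0 uses nn.Identity() as its head, so the copy
# is skipped instead of raising an AttributeError on .weight / .bias.
assert not copy_head_if_present(nn.Identity(), torch.zeros(192, 21843), torch.zeros(21843))
assert copy_head_if_present(nn.Linear(192, 21843), torch.zeros(192, 21843), torch.zeros(21843))
```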
@@ -673,6 +676,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Tiny (ViT-Ti/16).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
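These `_in21k` variants are registered model entry points, normally instantiated through the `timm` factory. A small usage sketch, assuming a timm release in which this name is registered and that the ImageNet-21k head uses the usual 21843 classes:

```python
import torch
import timm

# Build the ImageNet-21k ViT-Tiny entry point; pretrained=True would download the
# converted jax weights, which for this variant include the 21k-class head.
model = timm.create_model('vit_tiny_patch16_224_in21k', pretrained=False)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))

print(logits.shape)  # expected: torch.Size([1, 21843]) for the 21k classifier head
```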
@@ -683,6 +687,7 @@ def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/32)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -693,6 +698,7 @@ def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -703,9 +709,10 @@ def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -714,9 +721,10 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
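Dropping `representation_size` for the B/32 and B/16 in21k variants matches the NOTE added above: these checkpoints carry a usable 21k head and no pre-logits layer. Roughly, `representation_size` controls whether a Linear+Tanh `pre_logits` block sits between the pooled class token and the classifier. A simplified sketch of that effect (`build_pooled_head` is a hypothetical helper, not the full `VisionTransformer` constructor):

```python
from collections import OrderedDict
import torch.nn as nn

def build_pooled_head(embed_dim: int, num_classes: int, representation_size=None):
    # With representation_size set, a Linear+Tanh 'pre_logits' block is inserted
    # before the classifier; without it, pre_logits is an identity and the head
    # consumes embed_dim features directly.
    if representation_size:
        pre_logits = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(embed_dim, representation_size)),
            ('act', nn.Tanh()),
        ]))
        head = nn.Linear(representation_size, num_classes)
    else:
        pre_logits = nn.Identity()
        head = nn.Linear(embed_dim, num_classes)
    return nn.Sequential(pre_logits, head)

# ViT-B/16 in21k after this change: no pre-logits, 21k head straight off embed_dim.
print(build_pooled_head(768, 21843))
# ViT-L/32 in21k keeps representation_size=1024, so the pre-logits block remains.
print(build_pooled_head(1024, 21843, representation_size=1024))
```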
@@ -725,6 +733,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
     """
     model_kwargs = dict(
         patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
@@ -736,9 +745,10 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -747,7 +757,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
-    NOTE: converted weights not currently available, too large for github release hosting.
+    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
     """
     model_kwargs = dict(
         patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
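Since the original ViT-H/14 in21k checkpoint keeps its representation layer but ships with a zero'd-out classifier, one plausible way to use this variant is as a feature extractor with the head removed. A sketch only, assuming the entry point is registered under this name in the installed timm release (whether converted weights are hosted for it depends on the release):

```python
import torch
import timm

# num_classes=0 swaps the classifier for nn.Identity, so the forward pass
# returns the pooled pre-logits features instead of (zero'd) 21k logits.
model = timm.create_model('vit_huge_patch14_224_in21k', pretrained=False, num_classes=0)
model.eval()

with torch.no_grad():
    features = model(torch.randn(1, 3, 224, 224))

print(features.shape)  # pooled pre-logits features, e.g. torch.Size([1, 1280])
```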