@@ -6,7 +6,7 @@ A PyTorch implement of Vision Transformers as described in:
     - https://arxiv.org/abs/2010.11929
 
 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
-    - https://arxiv.org/abs/2106.TODO
+    - https://arxiv.org/abs/2106.10270
 
 The official jax code is released and available at https://github.com/google-research/vision_transformer
 
@@ -451,6 +451,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+    if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+        model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+        model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
         mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
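
The three `+` lines above copy an optional pre-logits (representation) layer from the JAX checkpoint into `model.pre_logits.fc` when one is present, mirroring how the classifier head is copied just above. A minimal sketch of the underlying kernel/bias copy, assuming the Flax convention of `(in_features, out_features)` dense kernels; the `copy_flax_dense` helper and the 768-wide layer are illustrative only (`_n2p` plays this role in the real code):

    import numpy as np
    import torch
    import torch.nn as nn

    def copy_flax_dense(linear: nn.Linear, kernel: np.ndarray, bias: np.ndarray) -> None:
        # Flax stores Dense kernels as (in_features, out_features), while
        # nn.Linear.weight is (out_features, in_features), hence the transpose.
        with torch.no_grad():
            linear.weight.copy_(torch.from_numpy(kernel).T)
            linear.bias.copy_(torch.from_numpy(bias))

    # Hypothetical 768-wide pre-logits projection filled from dummy arrays.
    fc = nn.Linear(768, 768)
    copy_flax_dense(fc, np.zeros((768, 768), dtype=np.float32), np.zeros(768, dtype=np.float32))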
@@ -673,6 +676,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Tiny (Vit-Ti/16).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
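
The NOTE added here (and to the other `_in21k` variants below) means these models are now built without a representation layer while keeping the full ImageNet-21k classifier head. A quick sanity check, assuming the `timm` package is installed and that the attribute names (`head`, `pre_logits`) and the 21843-class default match timm's VisionTransformer; use `pretrained=True` to actually download the converted weights:

    import timm
    import torch

    model = timm.create_model('vit_tiny_patch16_224_in21k', pretrained=False)
    print(model.head)        # expected: Linear(..., out_features=21843), the 21k classifier head
    print(model.pre_logits)  # expected: Identity(), i.e. no pre-logits layer

    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)      # expected: torch.Size([1, 21843])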
@@ -683,6 +687,7 @@ def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -693,6 +698,7 @@ def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -703,9 +709,10 @@ def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
 
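
Dropping `representation_size=768` from `model_kwargs` is what removes the pre-logits layer for this variant: when the argument is left unset, the classifier sits directly on the final embedding. An illustrative sketch of that toggle (not timm's exact code; names are assumptions):

    from collections import OrderedDict
    from typing import Optional, Tuple

    import torch.nn as nn

    def build_head(embed_dim: int, num_classes: int,
                   representation_size: Optional[int] = None) -> Tuple[nn.Module, nn.Module]:
        if representation_size:
            # Pre-logits block: Linear + Tanh projection feeding the classifier.
            pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh()),
            ]))
            head = nn.Linear(representation_size, num_classes)
        else:
            # No representation layer: pooled features feed the classifier directly.
            pre_logits = nn.Identity()
            head = nn.Linear(embed_dim, num_classes)
        return pre_logits, head

    # After this change: build_head(768, 21843) -> (Identity(), Linear(768, 21843))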
@@ -714,9 +721,10 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -725,6 +733,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
     """
     model_kwargs = dict(
         patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
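
Unlike the variants above, this one keeps its representation layer, and the NOTE warns that the released 21k classifier weights are zeroed out, so the head should be replaced before fine-tuning or evaluation. A hedged usage sketch, assuming timm's `create_model()` `num_classes` argument and the VisionTransformer `reset_classifier()` method behave as in current releases:

    import timm

    # Either reset the head after creation...
    model = timm.create_model('vit_large_patch32_224_in21k', pretrained=False)
    model.reset_classifier(num_classes=1000)  # fresh 1000-class head for fine-tuning

    # ...or request the desired head size up front.
    model = timm.create_model('vit_large_patch32_224_in21k', pretrained=False, num_classes=1000)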
@@ -736,9 +745,10 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -747,7 +757,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
-    NOTE: converted weights not currently available, too large for github release hosting.
+    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
     """
     model_kwargs = dict(
         patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)