Fix ViT in21k representation (pre_logits) layer handling across old and new npz checkpoints

vit_and_bit_test_fixes
Ross Wightman 3 years ago
parent b41cffaa93
commit 85f894e03d

@ -424,7 +424,7 @@ def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/
model.stem.conv.weight.copy_(stem_conv_w)
model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
if isinstance(model.head.fc, nn.Conv2d) and \
if isinstance(getattr(model.head, 'fc', None), nn.Conv2d) and \
model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))

@ -6,7 +6,7 @@ A PyTorch implement of Vision Transformers as described in:
- https://arxiv.org/abs/2010.11929
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
- https://arxiv.org/abs/2106.TODO
- https://arxiv.org/abs/2106.10270
The official jax code is released and available at https://github.com/google-research/vision_transformer
@ -451,6 +451,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
@ -673,6 +676,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
@ -683,6 +687,7 @@ def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
@ -693,6 +698,7 @@ def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
@ -703,9 +709,10 @@ def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -714,9 +721,10 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -725,6 +733,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
"""
model_kwargs = dict(
patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
@ -736,9 +745,10 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -747,7 +757,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: converted weights not currently available, too large for github release hosting.
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)

Loading…
Cancel
Save