diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py
index 8110fcca..b96d7742 100644
--- a/timm/models/resnetv2.py
+++ b/timm/models/resnetv2.py
@@ -424,7 +424,7 @@ def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/
     model.stem.conv.weight.copy_(stem_conv_w)
     model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
     model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
-    if isinstance(model.head.fc, nn.Conv2d) and \
+    if isinstance(getattr(model.head, 'fc', None), nn.Conv2d) and \
             model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
         model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
         model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 0a960987..9ec45868 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -6,7 +6,7 @@ A PyTorch implement of Vision Transformers as described in:
     - https://arxiv.org/abs/2010.11929
 
 `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
-    - https://arxiv.org/abs/2106.TODO
+    - https://arxiv.org/abs/2106.10270
 
 The official jax code is released and available at https://github.com/google-research/vision_transformer
@@ -451,6 +451,9 @@
     if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+    if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+        model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+        model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
     for i, block in enumerate(model.blocks.children()):
         block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
         mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
@@ -673,6 +676,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Tiny (Vit-Ti/16). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -683,6 +687,7 @@ def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16) ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -693,6 +698,7 @@ def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16) ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
@@ -703,9 +709,10 @@ def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -714,9 +721,10 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -725,6 +733,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
     """
     model_kwargs = dict(
         patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
@@ -736,9 +745,10 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
 def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+    NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -747,7 +757,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
 def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
-    NOTE: converted weights not currently available, too large for github release hosting.
+    NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
     """
     model_kwargs = dict(
         patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
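
For context on the `getattr(..., 'fc', None)` guards in the two npz loaders: when a model is built without a classifier (e.g. `num_classes=0`), the head or its `fc` may be an `nn.Identity` or absent, so the old `model.head.fc` access could fail even though there is nothing to load. The sketch below is a minimal, self-contained illustration of that guarded-copy pattern; `maybe_copy_head` and the dummy `weights` dict are hypothetical stand-ins, not timm code.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-ins: one head carries an `fc` classifier, the other is the
# Identity you get when the classifier has been removed. `weights` mimics a
# single entry of a numpy .npz checkpoint (HWIO conv kernel, as in BiT/ViT files).
head_with_fc = nn.Module()
head_with_fc.fc = nn.Conv2d(32, 10, kernel_size=1)
head_without_fc = nn.Identity()
weights = {'head/conv2d/kernel': np.zeros((1, 1, 32, 10), dtype=np.float32)}

def maybe_copy_head(head, weights, prefix=''):
    """Copy the checkpoint classifier only if a compatible fc layer exists."""
    fc = getattr(head, 'fc', None)   # None instead of AttributeError on Identity heads
    kernel = weights[f'{prefix}head/conv2d/kernel']
    if isinstance(fc, nn.Conv2d) and fc.weight.shape[0] == kernel.shape[-1]:
        with torch.no_grad():
            # HWIO -> OIHW, mirroring the t2p() conversion used by the npz loader
            fc.weight.copy_(torch.from_numpy(kernel.transpose(3, 2, 0, 1)))
        return True
    return False

print(maybe_copy_head(head_with_fc, weights))     # True: shapes line up, weights copied
print(maybe_copy_head(head_without_fc, weights))  # False: no fc attribute, safely skipped
```

The new ViT branch follows the same idea: `model.pre_logits` is an `nn.Identity` when no representation layer was requested, so the loader only copies `pre_logits/kernel` and `pre_logits/bias` when a real `fc` exists and the keys are actually present in the checkpoint.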
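The `representation_size` removals line up with the NOTE lines added above: per those docstrings, the in21k Ti/16, S/32, S/16, B/32, B/16 and L/16 checkpoints ship a valid 21k classifier head and no pre-logits layer, while L/32 and H/14 keep the pre-logits layer with a zeroed-out head. In timm's VisionTransformer, `representation_size` is the flag that builds the Linear + Tanh `pre_logits` block; the sketch below mirrors that wiring in a simplified form (`SimpleViTHead` is illustrative only, not the actual timm class).

```python
from collections import OrderedDict
import torch.nn as nn

class SimpleViTHead(nn.Module):
    """Simplified sketch of how representation_size controls the pre-logits block."""

    def __init__(self, embed_dim=768, num_classes=21843, representation_size=None):
        super().__init__()
        if representation_size:
            # pre-logits: Linear + Tanh projecting the CLS token to representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh()),
            ]))
            head_dim = representation_size
        else:
            # no representation layer: pre_logits is a pass-through Identity
            self.pre_logits = nn.Identity()
            head_dim = embed_dim
        self.head = nn.Linear(head_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, cls_token):
        return self.head(self.pre_logits(cls_token))

# in21k B/16 style config: valid 21k head, no pre-logits layer
print(SimpleViTHead(embed_dim=768, representation_size=None).pre_logits)    # Identity()
# in21k L/32 / H/14 style config: pre-logits layer present
print(SimpleViTHead(embed_dim=1024, representation_size=1024).pre_logits)   # Sequential(fc, act)
```

Dropping `representation_size` from the affected in21k entry points therefore keeps the constructed module structure consistent with what the converted npz files contain, and `getattr(model.pre_logits, 'fc', None)` in the loader resolves to a real `nn.Linear` only for the checkpoints that actually carry pre-logits weights.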