From 6d8272e92c3d5f13a9fdd91dfe1eb7fae6784589 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 8 Jul 2021 11:23:55 -0700
Subject: [PATCH] Add SAM pretrained model defs/weights for ViT B16 and B32 models.

---
 README.md                         |  3 +++
 timm/models/vision_transformer.py | 27 ++++++++++++++++++++++++++-
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 52acfbb9..a7e10290 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,9 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 
 ## What's New
 
+### July 8, 2021
+* Add [SAM pretrained](https://arxiv.org/abs/2106.01548) in1k weight for ViT B/16 (`vit_base_patch16_sam_224`) and B/32 (`vit_base_patch32_sam_224`) models.
+
 ### July 5, 2021
 * Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare).
 
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 9ec45868..e0c904f7 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -129,6 +129,12 @@ default_cfgs = {
         hf_hub='timm/vit_huge_patch14_224_in21k',
         num_classes=21843),
 
+    # SAM trained models (https://arxiv.org/abs/2106.01548)
+    'vit_base_patch32_sam_224': _cfg(
+        url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
+    'vit_base_patch16_sam_224': _cfg(
+        url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
+
     # deit models (FB weights)
     'deit_tiny_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
@@ -596,7 +602,8 @@ def vit_small_patch16_384(pretrained=False, **kwargs):
 
 @register_model
 def vit_base_patch32_224(pretrained=False, **kwargs):
-    """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
+    """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
     """
     model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
@@ -672,6 +679,24 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def vit_base_patch16_sam_224(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_patch32_sam_224(pretrained=False, **kwargs):
+    """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
+    """
+    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
     """ ViT-Tiny (Vit-Ti/16).
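
Note (not part of the patch): a minimal usage sketch, assuming the patch above is applied and this branch of timm is installed, showing how the newly registered SAM variants would be instantiated via the standard `timm.create_model` entry point; the 224x224 RGB input shape is inferred from the `_224` name suffix, and the 1000-class output is an assumption based on the in1k weights.

# Minimal usage sketch (assumes the patch is applied and timm from this branch is importable).
import torch
import timm

# Model names come from the @register_model defs added in vision_transformer.py above.
model = timm.create_model('vit_base_patch16_sam_224', pretrained=True)
model.eval()

x = torch.randn(1, 3, 224, 224)  # 224x224 RGB input, per the `_224` model name suffix
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected torch.Size([1, 1000]) for the in1k SAM weights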