From 708d87a813cf208b2f103b92c2e8029c8062cdd9 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 27 Aug 2021 09:20:13 -0700
Subject: [PATCH] Fix ViT SAM weight compat as weights at URL changed to not
 use repr layer. Fix #825. Tweak optim test.

---
 tests/test_optim.py               | 2 +-
 timm/models/vision_transformer.py | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/tests/test_optim.py b/tests/test_optim.py
index c12e33cc..a0fe994e 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -320,7 +320,7 @@ def test_sgd(optimizer):
         lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1)
     )
     _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1, weight_decay=1)
+        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1, weight_decay=.1)
     )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index e3bcb6fe..de8248fe 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -683,7 +683,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 def vit_base_patch16_sam_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    # NOTE original SAM weights release worked with representation_size=768
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs)
     return model
 
@@ -692,7 +693,8 @@ def vit_base_patch16_sam_224(pretrained=False, **kwargs):
 def vit_base_patch32_sam_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
     """
-    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+    # NOTE original SAM weights release worked with representation_size=768
+    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs)
     return model
 
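
Note (not part of the patch): a quick way to sanity-check the compat fix after
applying it is to build one of the SAM ViTs and confirm the representation
(pre-logits) layer is gone, since with a falsy representation_size timm's
VisionTransformer of this era falls back to an nn.Identity pre_logits module.
This is a minimal sketch assuming the patched timm is importable;
timm.create_model and the pre_logits attribute are real timm APIs, but the
exact assertion and dummy shapes below are illustrative.

    import torch
    import torch.nn as nn
    import timm

    # With representation_size=0 the SAM ViTs should build without the
    # pre-logits (representation) layer, matching the re-released weights
    # at the URL that dropped the repr layer.
    model = timm.create_model('vit_base_patch16_sam_224', pretrained=False)
    assert isinstance(model.pre_logits, nn.Identity)

    # Forward a dummy batch to confirm the classifier head still works.
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 1000])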