diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 65acacab..7e41f8a3 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -109,6 +109,8 @@ default_cfgs = { 'vit_giant_patch14_224': _cfg(url=''), 'vit_gigantic_patch14_224': _cfg(url=''), + 'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), + # patch models, imagenet21k (weights from official Google JAX impl) 'vit_tiny_patch16_224_in21k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', @@ -202,6 +204,7 @@ default_cfgs = { class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads head_dim = dim // num_heads self.scale = head_dim ** -0.5 @@ -634,6 +637,16 @@ def vit_base_patch32_224(pretrained=False, **kwargs): return model +@register_model +def vit_base2_patch32_256(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32) + # FIXME experiment + """ + model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs) + model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_patch32_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).