diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index bd9ea231..72f3a61a 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -48,7 +48,8 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', ), 'vit_base_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), ), 'vit_base_patch16_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', @@ -56,7 +57,9 @@ default_cfgs = { 'vit_base_patch32_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), - 'vit_large_patch16_224': _cfg(), + 'vit_large_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 'vit_large_patch16_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), @@ -305,10 +308,9 @@ def vit_small_patch16_224(pretrained=False, **kwargs): @register_model def vit_base_patch16_224(pretrained=False, **kwargs): - if pretrained: - # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model - kwargs.setdefault('qk_scale', 768 ** -0.5) - model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) + model = VisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model.default_cfg = default_cfgs['vit_base_patch16_224'] if pretrained: load_pretrained( @@ -340,8 +342,12 @@ def vit_base_patch32_384(pretrained=False, **kwargs): @register_model def vit_large_patch16_224(pretrained=False, **kwargs): - model = VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) + model = VisionTransformer( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model.default_cfg = default_cfgs['vit_large_patch16_224'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model