diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index b9857ed2..57380deb 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -29,7 +29,7 @@ import torch
 import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg
+from .helpers import load_pretrained
 from .layers import DropPath, to_2tuple, trunc_normal_
 from .resnet import resnet26d, resnet50d
 from .registry import register_model
@@ -48,7 +48,9 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = {
     # patch models
-    'vit_small_patch16_224': _cfg(),
+    'vit_small_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+    ),
     'vit_base_patch16_224': _cfg(),
     'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
     'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
@@ -271,6 +273,9 @@ class VisionTransformer(nn.Module):
 def vit_small_patch16_224(pretrained=False, **kwargs):
     model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
     model.default_cfg = default_cfgs['vit_small_patch16_224']
+    if pretrained:
+        load_pretrained(
+            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
     return model
 
 
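Not part of the patch: a minimal usage sketch, assuming a checkout containing this change and the usual timm.create_model entry point, to illustrate how the newly wired-up pretrained weights for vit_small_patch16_224 would be exercised.

# Usage sketch (assumption: run against a timm checkout that includes this patch).
import torch
import timm

# create_model resolves 'vit_small_patch16_224' via the @register_model registry;
# pretrained=True makes the factory call load_pretrained(), which downloads the
# checkpoint URL added to default_cfgs above.
model = timm.create_model('vit_small_patch16_224', pretrained=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000]) for the default 1000-class head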