Add small vision transformer weights. 77.42 top-1.

pull/268/head
Ross Wightman 4 years ago
parent ccfb5751ab
commit d4db9e7977

@ -29,7 +29,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .helpers import load_pretrained
from .layers import DropPath, to_2tuple, trunc_normal_
from .resnet import resnet26d, resnet50d
from .registry import register_model
@ -48,7 +48,9 @@ def _cfg(url='', **kwargs):
default_cfgs = {
# patch models
'vit_small_patch16_224': _cfg(),
'vit_small_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
),
'vit_base_patch16_224': _cfg(),
'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
@ -271,6 +273,9 @@ class VisionTransformer(nn.Module):
def vit_small_patch16_224(pretrained=False, **kwargs):
model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
model.default_cfg = default_cfgs['vit_small_patch16_224']
if pretrained:
load_pretrained(
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
return model

Loading…
Cancel
Save