|
|
|
@ -29,7 +29,7 @@ import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from .helpers import build_model_with_cfg
|
|
|
|
|
from .helpers import load_pretrained
|
|
|
|
|
from .layers import DropPath, to_2tuple, trunc_normal_
|
|
|
|
|
from .resnet import resnet26d, resnet50d
|
|
|
|
|
from .registry import register_model
|
|
|
|
@ -48,7 +48,9 @@ def _cfg(url='', **kwargs):
|
|
|
|
|
|
|
|
|
|
default_cfgs = {
|
|
|
|
|
# patch models
|
|
|
|
|
'vit_small_patch16_224': _cfg(),
|
|
|
|
|
'vit_small_patch16_224': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
|
|
|
|
|
),
|
|
|
|
|
'vit_base_patch16_224': _cfg(),
|
|
|
|
|
'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
|
|
|
|
|
'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
|
|
|
|
@ -271,6 +273,9 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
def vit_small_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
|
|
|
|
|
model.default_cfg = default_cfgs['vit_small_patch16_224']
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(
|
|
|
|
|
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|