|
|
|
""" Vision Transformer (ViT) in PyTorch
|
|
|
|
|
|
|
|
A PyTorch implementation of Vision Transformers as described in
|
|
|
|
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
|
|
|
|
|
|
|
|
The official jax code is released and available at https://github.com/google-research/vision_transformer
|
|
|
|
|
|
|
|
Acknowledgments:
|
|
|
|
* The paper authors for releasing code and weights, thanks!
|
|
|
|
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
|
|
|
|
for some einops/einsum fun
|
|
|
|
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
|
|
|
|
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
|
|
|
|
|
|
|
|
DeiT model defs and weights from https://github.com/facebookresearch/deit,
|
|
|
|
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
|
|
|
|
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
"""
|
|
|
|
import math
|
|
|
|
import logging
|
|
|
|
from functools import partial
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
from .helpers import load_pretrained
|
|
|
|
from .layers import DropPath, to_2tuple, trunc_normal_
|
|
|
|
from .resnet import resnet26d, resnet50d
|
|
|
|
from .resnetv2 import ResNetV2, StdConv2dSame
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
# Module-level logger; used below when resizing position embeddings and dropping the
# representation layer for fine-tuning.
_logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs):
    """ Build the default pretrained config dict for a ViT variant.

    Starts from the shared defaults (1000 classes, 3x224x224 input, bicubic interpolation,
    crop_pct 0.9, ImageNet mean/std, ``patch_embed.proj`` as first conv, ``head`` as the
    classifier). Any ``kwargs`` override those values or add new keys.
    """
    config = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.proj',
        classifier='head',
    )
    # overrides replace values in place (key order is kept); new keys are appended
    config.update(kwargs)
    return config
|
|
|
|
|
|
|
|
|
|
|
|
# Pretrained configs keyed by model variant name. Every entry is built by _cfg(), which
# supplies the 224x224 / 1000-class / ImageNet-normalization defaults that entries override.
default_cfgs = dict(
    # patch models (my experiments)
    vit_small_patch16_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
    ),

    # patch models (weights ported from official Google JAX impl); mean = std = 0.5 per channel
    vit_base_patch16_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    vit_base_patch32_224=_cfg(
        url='',  # no official model weights for this combo, only for in21k
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    vit_base_patch16_384=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
    ),
    vit_base_patch32_384=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
    ),
    vit_large_patch16_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    vit_large_patch32_224=_cfg(
        url='',  # no official model weights for this combo, only for in21k
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    vit_large_patch16_384=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
    ),
    vit_large_patch32_384=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
    ),

    # patch models, imagenet21k (weights ported from official Google JAX impl), 21843 classes
    vit_base_patch16_224_in21k=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    vit_base_patch32_224_in21k=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    vit_large_patch16_224_in21k=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    vit_large_patch32_224_in21k=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    vit_huge_patch14_224_in21k=_cfg(
        url='',  # FIXME I have weights for this but > 2GB limit for github release binaries
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),

    # hybrid models (weights ported from official Google JAX impl)
    vit_base_resnet50_224_in21k=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9,
    ),
    vit_base_resnet50_384=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
    ),

    # hybrid models (my experiments), no pretrained urls
    vit_small_resnet26d_224=_cfg(),
    vit_small_resnet50d_s3_224=_cfg(),
    vit_base_resnet26d_224=_cfg(),
    vit_base_resnet50d_224=_cfg(),

    # deit models (FB weights), default ImageNet mean/std
    vit_deit_tiny_patch16_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
    ),
    vit_deit_small_patch16_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
    ),
    vit_deit_base_patch16_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
    ),
    vit_deit_base_patch16_384=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
        input_size=(3, 384, 384), crop_pct=1.0,
    ),
)
|
|
|
|
|
|
|
|
|
|
|
|
class Mlp(nn.Module):
    """ Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: input channel count.
        hidden_features: hidden channel count; falls back to ``in_features`` when falsy.
        out_features: output channel count; falls back to ``in_features`` when falsy.
        act_layer: activation module class, instantiated with no arguments.
        drop: dropout probability; the same Dropout module is applied after both stages.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        out = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
|
|
|
|
|
|
|
|
|
|
|
class Attention(nn.Module):
    """ Multi-head self-attention over a token sequence.

    A single Linear produces queries, keys and values for all heads. The attention weights are
    ``softmax(q @ k^T * scale)``. Per-head outputs are concatenated and projected back to
    ``dim`` channels.

    Args:
        dim: input / output channel count, split across ``num_heads`` (dim // num_heads per head).
        num_heads: number of attention heads.
        qkv_bias: add a bias to the qkv projection.
        qk_scale: explicit attention scale; any falsy value falls back to ``head_dim ** -0.5``.
        attn_drop: dropout applied to the attention weights.
        proj_drop: dropout applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale if qk_scale else (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape
        heads = self.num_heads
        # (B, N, 3 * C) -> (3, B, heads, N, C // heads)
        qkv = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads).permute(2, 0, 3, 1, 4)
        # index rather than unpack to make torchscript happy (cannot use tensor as tuple)
        q = qkv[0]
        k = qkv[1]
        v = qkv[2]

        weights = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(weights.softmax(dim=-1))

        # merge heads back: (B, heads, N, C // heads) -> (B, N, C)
        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))
|
|
|
|
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """ Pre-norm transformer block: ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.

    Both residual branches go through the same ``drop_path`` module. That module is DropPath
    when ``drop_path > 0`` and nn.Identity otherwise.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_out)
|
|
|
|
|
|
|
|
|
|
|
|
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    A conv with kernel = stride = ``patch_size`` splits the image into non-overlapping patches.
    Each patch becomes one ``embed_dim`` token, so the output is (B, num_patches, embed_dim).
    Inputs must match ``img_size`` exactly.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        grid_rows = self.img_size[0] // self.patch_size[0]
        grid_cols = self.img_size[1] // self.patch_size[1]
        self.num_patches = grid_rows * grid_cols

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        _, _, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        patches = self.proj(x)  # (B, embed_dim, grid_rows, grid_cols)
        return patches.flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
|
|
class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.

    The backbone's (last) feature map is projected to ``embed_dim`` with a 1x1 conv, and every
    spatial position becomes one token. If ``feature_size`` is not given, the feature grid size
    and channel count are discovered by running the backbone once on an all-zeros
    ``(1, in_chans, *img_size)`` input.
    """

    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        self.img_size = img_size = to_2tuple(img_size)
        self.backbone = backbone
        if feature_size is not None:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        else:
            feature_size, feature_dim = self._probe_feature_shape(self.backbone, in_chans, img_size)
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Conv2d(feature_dim, embed_dim, 1)

    @staticmethod
    def _probe_feature_shape(backbone, in_chans, img_size):
        """ Run ``backbone`` once on a zero image and return (feature (H, W), feature channels). """
        with torch.no_grad():
            # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
            # map for all networks, the feature metadata has reliable channel and stride info, but using
            # stride to calc feature dim requires info about padding of each stage that isn't captured.
            was_training = backbone.training
            if was_training:
                backbone.eval()
            feat = backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
            if isinstance(feat, (list, tuple)):
                feat = feat[-1]  # last feature if backbone outputs list/tuple of features
            grid_hw, channels = feat.shape[-2:], feat.shape[1]
            backbone.train(was_training)
        return grid_hw, channels

    def forward(self, x):
        feats = self.backbone(x)
        if isinstance(feats, (list, tuple)):
            feats = feats[-1]  # last feature if backbone outputs list/tuple of features
        tokens = self.proj(feats)
        return tokens.flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
|
|
class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    Args:
        img_size: input image size (int or (H, W)), passed to the patch / hybrid embedding.
        patch_size: patch size for PatchEmbed (unused when ``hybrid_backbone`` is given).
        in_chans: number of input image channels.
        num_classes: number of classifier outputs; 0 replaces the head with nn.Identity.
        embed_dim: token embedding dimension.
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: MLP hidden dim as a multiple of ``embed_dim``.
        qkv_bias: add a bias to the qkv projections.
        qk_scale: override for the attention scale.
        representation_size: if set, a Linear + Tanh "pre-logits" layer of this width is inserted
            before the head, and ``num_features`` becomes this value.
        drop_rate: dropout rate.
        attn_drop_rate: attention-weight dropout rate.
        drop_path_rate: stochastic depth rate, increased linearly from 0 across the blocks.
        hybrid_backbone: CNN module used instead of patch embedding (wrapped in HybridEmbed).
        norm_layer: normalization layer factory (default: LayerNorm with eps=1e-6).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # one extra position for the prepended class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head (consumes num_features, i.e. the pre-logits width when present)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """ Truncated-normal Linear weights with zero bias; LayerNorm weight 1 and bias 0. """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """ Parameter names that should be excluded from weight decay. """
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """ Replace the classifier head with a new one producing ``num_classes`` outputs.

        The new head takes ``num_features`` inputs, the same width ``__init__`` uses. That is
        ``representation_size`` when a pre-logits layer exists, otherwise ``embed_dim``.
        ``global_pool`` is accepted for API compatibility and ignored.
        """
        self.num_classes = num_classes
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """ Embed, run the transformer blocks, and return the normalized class-token features
        after the pre-logits layer. """
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)[:, 0]
        x = self.pre_logits(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
|
|
|
|
|
|
|
|
|
|
|
|
def resize_pos_embed(posemb, posemb_new, num_tokens=1):
    """ Rescale the grid of position embeddings when loading from state_dict.

    Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224

    Args:
        posemb: checkpoint position embedding of shape (1, num_tokens + gs_old ** 2, dim).
        posemb_new: target position embedding; only ``posemb_new.shape[1]`` is used.
        num_tokens: number of leading non-grid tokens (e.g. the class token) that are copied
            through unchanged. Defaults to 1, the single class token of VisionTransformer.

    Returns:
        Tensor of shape (1, num_tokens + gs_new ** 2, dim), with the square patch grid resized by
        bilinear interpolation.
    """
    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
    ntok_new = posemb_new.shape[1] - num_tokens
    # split off the extra tokens (an empty slice when num_tokens == 0) from the patch grid
    posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
    gs_old = int(math.sqrt(len(posemb_grid)))
    gs_new = int(math.sqrt(ntok_new))
    _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    # align_corners=False is the PyTorch default; passing it explicitly silences the warning
    posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear', align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
    return torch.cat([posemb_tok, posemb_grid], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
    """ convert patch embedding weight from manual patchify + linear proj to conv

    Adapts a pretrained checkpoint to ``model`` before it is loaded:
      * checkpoints that wrap their weights in a ``'model'`` entry (deit models) are unwrapped;
      * a ``patch_embed.proj.weight`` with fewer than 4 dims is reshaped to the model's conv weight;
      * a ``pos_embed`` whose shape differs from the model's is resized with ``resize_pos_embed``.
    All other entries pass through unchanged, in their original order.
    """
    if 'model' in state_dict:
        state_dict = state_dict['model']  # for deit models

    def _adapt(name, tensor):
        if 'patch_embed.proj.weight' in name and len(tensor.shape) < 4:
            # for old models that I trained prior to conv based patchification
            return tensor.reshape(model.patch_embed.proj.weight.shape)
        if name == 'pos_embed' and tensor.shape != model.pos_embed.shape:
            # to resize pos embedding when using model at different size from pretrained weights
            return resize_pos_embed(tensor, model.pos_embed)
        return tensor

    return {name: _adapt(name, tensor) for name, tensor in state_dict.items()}
|
|
|
|
|
|
|
|
|
|
|
|
def _create_vision_transformer(variant, pretrained=False, **kwargs):
    """ Build a VisionTransformer for ``variant`` and optionally load its pretrained weights.

    ``num_classes`` and ``img_size`` default to the variant's ``default_cfgs`` entry, with
    ``input_size[-1]`` used as the image size. A requested ``representation_size`` is dropped
    when ``num_classes`` differs from the config's (fine-tuning to a new label space). All
    remaining ``kwargs`` are forwarded to ``VisionTransformer``.
    """
    cfg = default_cfgs[variant]
    cfg_num_classes = cfg['num_classes']
    cfg_img_size = cfg['input_size'][-1]

    num_classes = kwargs.pop('num_classes', cfg_num_classes)
    img_size = kwargs.pop('img_size', cfg_img_size)
    representation_size = kwargs.pop('representation_size', None)
    if representation_size is not None and num_classes != cfg_num_classes:
        # remove representation layer if fine-tuning
        _logger.info("Removing representation layer for fine-tuning.")
        representation_size = None

    model = VisionTransformer(
        img_size=img_size, num_classes=num_classes, representation_size=representation_size, **kwargs)
    model.default_cfg = cfg
    if not pretrained:
        return model

    load_pretrained(
        model,
        num_classes=num_classes,
        in_chans=kwargs.get('in_chans', 3),
        filter_fn=partial(checkpoint_filter_fn, model=model),
    )
    return model
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
    """ My experimental 'small' ViT: 16x16 patches, dim 768, depth 8, 8 heads, MLP ratio 3.

    It uses no qkv bias and plain ``nn.LayerNorm`` (default eps), not the eps=1e-6 LayerNorm
    VisionTransformer uses by default. The original weights were trained with an attention
    scale of ``768 ** -0.5``. That scale is applied when ``pretrained=True``, unless the caller
    passes ``qk_scale`` explicitly.
    """
    embed_dim = 768
    model_kwargs = dict(
        patch_size=16, embed_dim=embed_dim, depth=8, num_heads=8, mlp_ratio=3.,
        qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
    # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
    if pretrained and 'qk_scale' not in model_kwargs:
        model_kwargs['qk_scale'] = embed_dim ** -0.5
    return _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base, 16x16 patches, 224x224 input.

    Pretrained weights ported from the official Google JAX impl.
    """
    return _create_vision_transformer(
        'vit_base_patch16_224', pretrained=pretrained,
        **dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_patch32_224(pretrained=False, **kwargs):
    """ ViT-Base, 32x32 patches, 224x224 input.

    Its ``default_cfgs`` entry has no weight url (official weights exist only for in21k).
    """
    return _create_vision_transformer(
        'vit_base_patch32_224', pretrained=pretrained,
        **dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
    """ ViT-Base, 16x16 patches, 384x384 input.

    Pretrained weights ported from the official Google JAX impl.
    """
    return _create_vision_transformer(
        'vit_base_patch16_384', pretrained=pretrained,
        **dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
    """ ViT-Base, 32x32 patches, 384x384 input.

    Pretrained weights ported from the official Google JAX impl. Like its sibling variants, this
    relies on VisionTransformer's default norm layer (LayerNorm, eps=1e-6). An explicit
    ``norm_layer`` passed here would duplicate that default and make a caller-supplied
    ``norm_layer`` kwarg raise TypeError.
    """
    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
    return model
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_large_patch16_224(pretrained=False, **kwargs):
    """ ViT-Large, 16x16 patches, 224x224 input.

    Pretrained weights ported from the official Google JAX impl.
    """
    return _create_vision_transformer(
        'vit_large_patch16_224', pretrained=pretrained,
        **dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_large_patch32_224(pretrained=False, **kwargs):
    """ ViT-Large, 32x32 patches, 224x224 input.

    Its ``default_cfgs`` entry has no weight url (official weights exist only for in21k).
    """
    return _create_vision_transformer(
        'vit_large_patch32_224', pretrained=pretrained,
        **dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_large_patch16_384(pretrained=False, **kwargs):
    """ ViT-Large, 16x16 patches, 384x384 input.

    Pretrained weights ported from the official Google JAX impl.
    """
    return _create_vision_transformer(
        'vit_large_patch16_384', pretrained=pretrained,
        **dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base, 16x16 patches, 224x224 input, ImageNet-21k (21843-class) weights.

    Includes a 768-wide pre-logits representation layer. ``_create_vision_transformer`` drops
    that layer when ``num_classes`` is changed.
    """
    return _create_vision_transformer(
        'vit_base_patch16_224_in21k', pretrained=pretrained,
        **dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
               representation_size=768, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_patch16_384_in21k(pretrained=False, **kwargs):
    """ ViT-Base, 16x16 patches, 384x384 input, built from the ImageNet-21k 224x224 config.

    ``default_cfgs`` has no 384 in21k entry, so this reuses ``'vit_base_patch16_224_in21k'``
    (including its pretrained weights). Its position embeddings are resized by
    ``checkpoint_filter_fn`` when loaded. The model is built at 384x384 unless ``img_size`` is
    passed. Previously the 224 config's input size was used, so the "384" model was actually
    224x224.

    The model gets its own copy of the config with the real ``input_size`` (and
    ``crop_pct=1.0``, as for the other 384 variants). This avoids mutating the dict shared with
    the 224 variant.
    """
    kwargs.setdefault('img_size', 384)
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    in_chans = model.default_cfg['input_size'][0]
    model.default_cfg = dict(
        model.default_cfg, input_size=(in_chans,) + tuple(model.patch_embed.img_size), crop_pct=1.0)
    return model
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base, 32x32 patches, 224x224 input, ImageNet-21k (21843-class) weights.

    Includes a 768-wide pre-logits representation layer.
    """
    return _create_vision_transformer(
        'vit_base_patch32_224_in21k', pretrained=pretrained,
        **dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
               representation_size=768, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large, 16x16 patches, 224x224 input, ImageNet-21k (21843-class) weights.

    Includes a 1024-wide pre-logits representation layer.
    """
    return _create_vision_transformer(
        'vit_large_patch16_224_in21k', pretrained=pretrained,
        **dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
               representation_size=1024, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
# @register_model
|
|
|
|
# def vit_large_patch16_384_in21k(pretrained=False, **kwargs):
|
|
|
|
# model_kwargs = dict(
|
|
|
|
# patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs)
|
|
|
|
# model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
# return model
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large, 32x32 patches, 224x224 input, ImageNet-21k (21843-class) weights.

    Includes a 1024-wide pre-logits representation layer.
    """
    return _create_vision_transformer(
        'vit_large_patch32_224_in21k', pretrained=pretrained,
        **dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
               representation_size=1024, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
    """ ViT-Huge, 14x14 patches, 224x224 input, ImageNet-21k (21843-class) config.

    Includes a 1280-wide pre-logits representation layer. The ``default_cfgs`` entry has no
    weight url (the weights exceed the GitHub release size limit).
    """
    return _create_vision_transformer(
        'vit_huge_patch14_224_in21k', pretrained=pretrained,
        **dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
               representation_size=1280, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
    """ Hybrid ViT-Base on a 3-stage ResNetV2 backbone (layers=(3, 4, 9)), 224x224 input.

    ImageNet-21k (21843-class) weights ported from the official Google JAX impl. Includes a
    768-wide pre-logits representation layer.
    """
    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
    backbone = ResNetV2(
        layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame,
        num_classes=0, global_pool='')
    return _create_vision_transformer(
        'vit_base_resnet50_224_in21k', pretrained=pretrained,
        **dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone,
               representation_size=768, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
    """ Hybrid ViT-Base on a 3-stage ResNetV2 backbone (layers=(3, 4, 9)), 384x384 input.

    Pretrained weights ported from the official Google JAX impl.
    """
    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
    backbone = ResNetV2(
        layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame,
        num_classes=0, global_pool='')
    return _create_vision_transformer(
        'vit_base_resnet50_384', pretrained=pretrained,
        **dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs):
    """ Experimental small hybrid ViT on a ResNet26-D feature backbone (out_indices=[4]).

    ``pretrained_backbone`` (default True) controls whether the backbone loads its own weights.
    It is consumed here because VisionTransformer.__init__ does not accept it. Forwarding it, as
    the old ``kwargs.get`` did, raised TypeError whenever a caller passed it.
    """
    pretrained_backbone = kwargs.pop('pretrained_backbone', True)  # default to True for now, for testing
    backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4])
    model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs)
    return model
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
    """ Experimental small hybrid ViT on a ResNet50-D feature backbone (out_indices=[3]).

    ``pretrained_backbone`` (default True) controls whether the backbone loads its own weights.
    It is consumed here because VisionTransformer.__init__ does not accept it. Forwarding it, as
    the old ``kwargs.get`` did, raised TypeError whenever a caller passed it.
    """
    pretrained_backbone = kwargs.pop('pretrained_backbone', True)  # default to True for now, for testing
    backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[3])
    model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs)
    return model
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_resnet26d_224(pretrained=False, **kwargs):
    """ Experimental base hybrid ViT on a ResNet26-D feature backbone (out_indices=[4]).

    ``pretrained_backbone`` (default True) controls whether the backbone loads its own weights.
    It is consumed here because VisionTransformer.__init__ does not accept it. Forwarding it, as
    the old ``kwargs.get`` did, raised TypeError whenever a caller passed it.
    """
    pretrained_backbone = kwargs.pop('pretrained_backbone', True)  # default to True for now, for testing
    backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4])
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs)
    return model
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_resnet50d_224(pretrained=False, **kwargs):
    """ Experimental base hybrid ViT on a ResNet50-D feature backbone (out_indices=[4]).

    ``pretrained_backbone`` (default True) controls whether the backbone loads its own weights.
    It is consumed here because VisionTransformer.__init__ does not accept it. Forwarding it, as
    the old ``kwargs.get`` did, raised TypeError whenever a caller passed it.
    """
    pretrained_backbone = kwargs.pop('pretrained_backbone', True)  # default to True for now, for testing
    backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[4])
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs)
    return model
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
    """ DeiT-Tiny: dim 192, 3 heads, 16x16 patches, 224x224 input (weights from facebookresearch/deit). """
    return _create_vision_transformer(
        'vit_deit_tiny_patch16_224', pretrained=pretrained,
        **dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_deit_small_patch16_224(pretrained=False, **kwargs):
    """ DeiT-Small: dim 384, 6 heads, 16x16 patches, 224x224 input (weights from facebookresearch/deit). """
    return _create_vision_transformer(
        'vit_deit_small_patch16_224', pretrained=pretrained,
        **dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_deit_base_patch16_224(pretrained=False, **kwargs):
    """ DeiT-Base: dim 768, 12 heads, 16x16 patches, 224x224 input (weights from facebookresearch/deit). """
    return _create_vision_transformer(
        'vit_deit_base_patch16_224', pretrained=pretrained,
        **dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_deit_base_patch16_384(pretrained=False, **kwargs):
    """ DeiT-Base: dim 768, 12 heads, 16x16 patches, 384x384 input (weights from facebookresearch/deit). """
    return _create_vision_transformer(
        'vit_deit_base_patch16_384', pretrained=pretrained,
        **dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs))
|