@@ -1,23 +1,18 @@
""" Vision Transformer (ViT) in PyTorch
This is a WIP attempt to implement Vision Transformers as described in
' An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale ' -
https : / / openreview . net / pdf ? id = YicbFdNTTy
A PyTorch implement of Vision Transformers as described in
' An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale ' - https : / / arxiv . org / abs / 2010.11929
The paper is currently under review and there is no official reference impl . The
code here is likely to change in the future and I will not make an effort to maintain
backwards weight compatibility when it does .
The official jax code is released and available at https : / / github . com / google - research / vision_transformer
Status / TODO :
* Trained ( supervised on ImageNet - 1 k ) my custom ' small ' patch model to ~ 75 top - 1 after 4 days , 2 x GPU ,
no dropout or stochastic depth active
* Need more time for supervised training results with dropout and drop connect active , hparam tuning
* Need more GPUs for SSL or unsupervised pretraining on OpenImages w / ImageNet fine - tune
* There are likely mistakes . If you notice any , I ' d love to improve this. This is my first time
fiddling with transformers / multi - head attn .
* Hopefully end up with worthwhile pretrained model at some point . . .
* Models updated to be compatible with official impl . Args added to support backward compat for old PyTorch weights .
* Weights ported from official jax impl for 384 x384 base and small models , 16 x16 and 32 x32 patches .
* Trained ( supervised on ImageNet - 1 k ) my custom ' small ' patch model to 77.9 , ' base ' to 79.4 top - 1 with this code .
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w / ImageNet fine - tune in future .
Acknowledgments :
* The paper authors for releasing code and weights , thanks !
* I fixed my class token impl based on Phil Wang ' s https://github.com/lucidrains/vit-pytorch ... check it out
for some einops / einsum fun
* Simple transformer style inspired by Andrej Karpathy ' s https://github.com/karpathy/minGPT
@@ -27,6 +22,7 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
@@ -52,13 +48,21 @@ default_cfgs = {
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
    ),
    'vit_base_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth'
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth',
    ),
    'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
    'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
    'vit_base_patch16_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_base_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_large_patch16_224': _cfg(),
    'vit_large_patch16_384': _cfg(input_size=(3, 384, 384)),
    'vit_large_patch32_384': _cfg(input_size=(3, 384, 384)),
    'vit_large_patch16_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_large_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_huge_patch16_224': _cfg(),
    'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)),
    # hybrid models
@@ -77,38 +81,35 @@ class Mlp(nn.Module):
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(drop)  # seems more common to have Transformer MLP dropout here?
        self.drop = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.drop(x)
        return x
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0.):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.scale = 1. / dim ** 0.5
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x, attn_mask=None):
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv[:, :, 0].transpose(1, 2), qkv[:, :, 1].transpose(1, 2), qkv[:, :, 2].transpose(1, 2)
        # TODO benchmark vs above
        #qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        #q, k, v = qkv
        q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # FIXME support masking
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
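# Editor's note: a minimal sketch of the scaled dot-product attention computed above, assuming
# input tokens of shape (B, N, C) split into num_heads heads of head_dim = C // num_heads; the
# variable values below are illustrative only and not part of this diff.
import torch
B, N, num_heads, head_dim = 2, 197, 12, 64
q = k = v = torch.randn(B, num_heads, N, head_dim)
scale = head_dim ** -0.5                    # the corrected default; qk_scale can override it
attn = (q @ k.transpose(-2, -1)) * scale    # (B, num_heads, N, N) similarity logits
attn = attn.softmax(dim=-1)                 # normalize over the key dimension
out = (attn @ v).transpose(1, 2).reshape(B, N, num_heads * head_dim)  # merge heads -> (B, N, C)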
@@ -120,52 +121,44 @@ class Attention(nn.Module):
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., act_layer=nn.GELU, drop=0., drop_path=0.):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, attn_drop=drop, proj_drop=drop)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
    def forward(self, x, attn_mask=None):
        x = x + self.drop_path(self.attn(self.norm1(x), attn_mask=attn_mask))
    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
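# Editor's note: the block above is a standard pre-norm residual transformer block,
#   x = x + DropPath(Attn(LN(x)));  x = x + DropPath(MLP(LN(x))).
# Below is a minimal sketch of per-sample stochastic depth in the spirit of DropPath (zero the
# residual branch for a random subset of samples and rescale the survivors); it is an
# illustration under those assumptions, not necessarily identical to timm's DropPath.
def _drop_path_sketch(x, drop_prob=0.1, training=True):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # one keep/drop decision per sample, broadcast over the remaining dims
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = (torch.rand(mask_shape, device=x.device) < keep_prob).to(x.dtype)
    return x * mask / keep_prob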
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    Unfold image into fixed size patches, flatten into seq, project to embedding dim.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, flatten_channels_last=False):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        assert img_size[0] % patch_size[0] == 0, 'image height must be divisible by the patch height'
        assert img_size[1] % patch_size[1] == 0, 'image width must be divisible by the patch width'
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        patch_dim = in_chans * patch_size[0] * patch_size[1]
        self.img_size = img_size
        self.patch_size = patch_size
        self.flatten_channels_last = flatten_channels_last
        self.num_patches = num_patches
        self.proj = nn.Linear(patch_dim, embed_dim)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        B, C, H, W = x.shape
        Ph, Pw = self.patch_size
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        if self.flatten_channels_last:
            # flatten patches with channels last like the paper (likely using TF)
            x = x.unfold(2, Ph, Ph).unfold(3, Pw, Pw).permute(0, 2, 3, 4, 5, 1).reshape(B, -1, Ph * Pw * C)
        else:
            x = x.permute(0, 2, 3, 1).unfold(1, Ph, Ph).unfold(2, Pw, Pw).reshape(B, -1, C * Ph * Pw)
        x = self.proj(x)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
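# Editor's note: a small sketch of why the Conv2d projection replaces the unfold + Linear path
# above: a convolution with kernel_size == stride == patch_size touches each patch exactly once,
# so it applies the same linear projection per patch. Shapes assume the defaults above (224x224
# input, 16x16 patches, embed_dim=768); the names are illustrative only.
import torch
import torch.nn as nn
img = torch.randn(1, 3, 224, 224)
proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
tokens = proj(img)                          # (1, 768, 14, 14): one embedding per 16x16 patch
tokens = tokens.flatten(2).transpose(1, 2)  # (1, 196, 768): sequence of 14*14 patch tokens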
@@ -208,37 +201,37 @@ class VisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__ ( self , img_size = 224 , patch_size = 16 , in_chans = 3 , num_classes = 1000 , embed_dim = 768 , depth = 12 ,
num_heads = 12 , mlp_ratio = 4. , mlp_head= False , drop_rate = 0. , drop_path _rate= 0. ,
flatten_channels_last= False , hybrid_backbone = None ) :
num_heads = 12 , mlp_ratio = 4. , qkv_bias= False , qk_scale = None , drop_rate = 0. , attn_drop _rate= 0. ,
drop_path_rate= 0. , hybrid_backbone = None , norm_layer = nn . LayerNorm ) :
super ( ) . __init__ ( )
if hybrid_backbone is not None :
self . patch_embed = HybridEmbed (
hybrid_backbone , img_size = img_size , in_chans = in_chans , embed_dim = embed_dim )
else :
self . patch_embed = PatchEmbed (
img_size = img_size , patch_size = patch_size , in_chans = in_chans , embed_dim = embed_dim ,
flatten_channels_last = flatten_channels_last )
img_size = img_size , patch_size = patch_size , in_chans = in_chans , embed_dim = embed_dim )
num_patches = self . patch_embed . num_patches
self . pos_embed = nn . Parameter ( torch . zeros ( 1 , num_patches + 1 , embed_dim ) )
self . cls_token = nn . Parameter ( torch . zeros ( 1 , 1 , embed_dim ) )
self . pos_embed = nn . Parameter ( torch . zeros ( 1 , num_patches + 1 , embed_dim ) )
self . pos_drop = nn . Dropout ( p = drop_rate )
dpr = [ x . item ( ) for x in torch . linspace ( 0 , drop_path_rate , depth ) ] # stochastic depth decay rule
self . blocks = nn . ModuleList ( [
Block ( dim = embed_dim , num_heads = num_heads , mlp_ratio = mlp_ratio , drop = drop_rate , drop_path = dpr [ i ] )
Block (
dim = embed_dim , num_heads = num_heads , mlp_ratio = mlp_ratio , qkv_bias = qkv_bias , qk_scale = qk_scale ,
drop = drop_rate , attn_drop = attn_drop_rate , drop_path = dpr [ i ] , norm_layer = norm_layer )
for i in range ( depth ) ] )
self . norm = norm_layer ( embed_dim )
self . norm = nn . LayerNorm ( embed_dim )
if mlp_head :
# paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
self . head = Mlp ( embed_dim , int ( embed_dim * mlp_ratio ) , num_classes )
else :
# with a single Linear layer as head, the param count within rounding of paper
self . head = nn . Linear ( embed_dim , num_classes )
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
#self.repr = nn.Linear(embed_dim, representation_size)
#self.repr_act = nn.Tanh()
# FIXME not quite sure what the proper weight init is supposed to be,
# normal / trunc normal w/ std == .02 similar to other Bert like transformers
trunc_normal_ ( self . pos_embed , std = .02 ) # embeddings same as weights?
# Classifier head
self . head = nn . Linear ( embed_dim , num_classes )
trunc_normal_ ( self . pos_embed , std = .02 )
trunc_normal_ ( self . cls_token , std = .02 )
self . apply ( self . _init_weights )
@@ -255,55 +248,80 @@ class VisionTransformer(nn.Module):
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}
    def forward(self, x, attn_mask=None):
    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embed
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x, attn_mask=attn_mask)
            x = blk(x)
        x = self.norm(x[:, 0])
        x = self.head(x)
        x = self.norm(x)
        x = self.head(x[:, 0])
        return x
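# Editor's note: a shape walkthrough of the forward pass above for the default ViT-Base/16
# configuration at 224x224 (the numbers follow directly from the code above):
#   patch tokens : (224 // 16) * (224 // 16) = 14 * 14 = 196, each projected to embed_dim = 768
#   + cls token  : sequence length becomes 196 + 1 = 197, matching pos_embed of shape (1, 197, 768)
#   blocks, norm : keep (B, 197, 768); the head reads only the cls token -> (B, num_classes)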
def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv """
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict
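# Editor's note: a small sketch of what _conv_filter does to an old checkpoint, assuming the
# previous nn.Linear patch projection; the state dict below is fabricated for illustration only.
import torch
old_state = {'patch_embed.proj.weight': torch.randn(768, 3 * 16 * 16),  # nn.Linear: (out, in)
             'patch_embed.proj.bias': torch.randn(768)}
new_state = _conv_filter(old_state, patch_size=16)
assert new_state['patch_embed.proj.weight'].shape == (768, 3, 16, 16)   # nn.Conv2d: (out, in, kH, kW)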
@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
    if pretrained:
        # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
        kwargs.setdefault('qk_scale', 768 ** -0.5)
    model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
    model.default_cfg = default_cfgs['vit_small_patch16_224']
    if pretrained:
        load_pretrained(
            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
    return model
@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
    if pretrained:
        # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
        kwargs.setdefault('qk_scale', 768 ** -0.5)
    model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
    model.default_cfg = default_cfgs['vit_base_patch16_224']
    if pretrained:
        load_pretrained(
            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
    return model
@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['vit_base_patch16_384']
    if pretrained:
        load_pretrained(
            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
        img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['vit_base_patch32_384']
    if pretrained:
        load_pretrained(
            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
@@ -317,16 +335,24 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
@register_model
def vit_large_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['vit_large_patch16_384']
    if pretrained:
        load_pretrained(
            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
@register_model
def vit_large_patch32_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
        img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = default_cfgs['vit_large_patch32_384']
    if pretrained:
        load_pretrained(
            model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
    return model
@@ -383,5 +409,3 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs):
        img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
    model.default_cfg = default_cfgs['vit_base_resnet50d_224']
    return model