""" Vision Transformer (ViT) in PyTorch
A PyTorch implementation of Vision Transformers as described in:

'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
    - https://arxiv.org/abs/2010.11929

`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
    - https://arxiv.org/abs/2106.TODO

The official JAX code is released and available at https://github.com/google-research/vision_transformer
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2021 Ross Wightman
"""
import math
import logging
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from .registry import register_model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }
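
# _cfg() merely merges per-model overrides into the defaults above; e.g. (illustrative)
#   _cfg(url='https://example.com/weights.npz', input_size=(3, 384, 384), crop_pct=1.0)
# keeps the bicubic interpolation and Inception mean/std defaults while swapping in a 384x384
# input size and a 1.0 crop percentage, which is how the 384-resolution entries below are built.
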
default_cfgs = {
    # patch models (weights from official Google JAX impl)
    'vit_tiny_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_tiny_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_small_patch32_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_small_patch32_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_small_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_small_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_base_patch32_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
    'vit_base_patch32_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_base_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
    'vit_base_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_large_patch32_224': _cfg(
        url='',  # no official model weights for this combo, only for in21k
    ),
    'vit_large_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_large_patch16_224': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
    'vit_large_patch16_384': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/'
            'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
        input_size=(3, 384, 384), crop_pct=1.0),

    # patch models, imagenet21k (weights from official Google JAX impl)
    'vit_tiny_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_small_patch32_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_small_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_base_patch32_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_base_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
        num_classes=21843),
    'vit_large_patch32_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
        num_classes=21843),
    'vit_large_patch16_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
        num_classes=21843),
    'vit_huge_patch14_224_in21k': _cfg(
        url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
        hf_hub='timm/vit_huge_patch14_224_in21k',
        num_classes=21843),

    # deit models (FB weights)
    'deit_tiny_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'deit_small_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'deit_base_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'deit_base_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
    'deit_tiny_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
    'deit_small_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
    'deit_base_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
    'deit_base_distilled_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
        classifier=('head', 'head_dist')),

    # ViT ImageNet-21K-P pretraining by MIIL
    'vit_base_patch16_224_miil_in21k': _cfg(
        url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
        mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
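        # NOTE (added comment): self.scale above is the standard 1/sqrt(head_dim) attention scaling factor
        # applied to the query-key dot products.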


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)


class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init=''):
        """
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_rate (float): dropout rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.dist_token is not None:
            trunc_normal_(self.dist_token, std=.02)
        if mode.startswith('jax'):
            # leave cls token as zeros to match jax impl
            named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
        else:
            trunc_normal_(self.cls_token, std=.02)
            self.apply(_init_vit_weights)
    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}
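
    # NOTE (added comment, illustrative): the parameter names returned by no_weight_decay() are the ones
    # timm's optimizer factory places in a zero-weight-decay parameter group, so the position, class and
    # distillation token embeddings are not decayed during training.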

        return x


def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
    """ ViT weight initialization
    * When called without name, head_bias, jax_impl args it will behave exactly the same
      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
    * When called w/ valid name (module name) and jax_impl=True, will (hopefully) match JAX impl
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        elif name.startswith('pre_logits'):
            lecun_normal_(module.weight)
            nn.init.zeros_(module.bias)
        else:
            if jax_impl:
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    if 'mlp' in name:
                        nn.init.normal_(module.bias, std=1e-6)
                    else:
                        nn.init.zeros_(module.bias)
            else:
                trunc_normal_(module.weight, std=.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    elif jax_impl and isinstance(module, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)
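
# Illustrative note (not part of the original file): the two init paths above correspond to the branches
# in VisionTransformer.init_weights(); a hypothetical standalone use would look like
#
#   model = VisionTransformer(weight_init='')          # trunc-normal init via model.apply(_init_vit_weights)
#   model = VisionTransformer(weight_init='jax_nlhb')  # JAX-style init with a negative-log head bias
#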

@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
    """ Load weights from .npz checkpoints for official Google Brain Flax implementation
    """
    import numpy as np

    def _n2p(w, t=True):
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    if not prefix and 'opt/target/embedding/kernel' in w:
        prefix = 'opt/target/'

    if hasattr(model.patch_embed, 'backbone'):
        # hybrid
        backbone = model.patch_embed.backbone
        stem_only = not hasattr(backbone, 'stem')
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
        if not stem_only:
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                    for r in range(3):
                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
    else:
        embed_conv_w = adapt_input_conv(
            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
    if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
    for i, block in enumerate(model.blocks.children()):
        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        block.attn.qkv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
        block.attn.qkv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        for r in range(2):
            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
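
# Illustrative sketch (not part of the original file): with a locally downloaded .npz checkpoint from one of
# the URLs in default_cfgs, the loader above can be invoked through the model method, e.g. (hypothetical path)
#
#   model = vit_base_patch16_224(pretrained=False)
#   model.load_pretrained('/path/to/B_16-i21k-300ep.npz')
#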


def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):


def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
        default_cfg=default_cfg,
        representation_size=repr_size,
        pretrained_filter_fn=checkpoint_filter_fn,
        pretrained_custom_load='npz' in default_cfg['url'],
        **kwargs)
    return model


@register_model
def vit_tiny_patch16_224(pretrained=False, **kwargs):
    """ ViT-Tiny (Vit-Ti/16)
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_tiny_patch16_384(pretrained=False, **kwargs):
    """ ViT-Tiny (Vit-Ti/16) @ 384x384.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch32_224(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/32)
    """
    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch32_384(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/32) at 384x384.
    """
    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/16)
    NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch16_384(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/16)
    NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch32_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    """
    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_224(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
    """
    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_384(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_224(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Tiny (Vit-Ti/16).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/32)
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/16)
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
    """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
    model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer(
        'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer(
        'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
    return model
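

# Minimal smoke test (illustrative only, not part of the original file). It builds one of the registered
# models defined above without pretrained weights and runs a random batch through it; the expected output
# is a (1, 1000) logit tensor for the default 1000-class head.
if __name__ == '__main__':
    model = vit_tiny_patch16_224(pretrained=False)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])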