@@ -28,7 +28,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
 from .registry import register_model
@@ -47,9 +47,18 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
     # patch models (my experiments)
+    # FIXME weights coming
+    'vit_tiny_patch16_224': _cfg(
+        url='',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+    ),
     'vit_small_patch16_224': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+        url='',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
     ),
+    'vit_small_patch32_224': _cfg(
+        url='',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+    ),
 
     # patch models (weights ported from official Google JAX impl)
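The new vit_tiny/vit_small entries switch to a 0.5 mean/std, i.e. the inception-style normalization used by the original JAX ViT preprocessing, rather than the ImageNET defaults used by most other entries in this file. A minimal sketch (illustrative only, not part of this change) of the eval transform such a cfg implies, assuming the _cfg default crop_pct of 0.9:

    # Illustrative only: build an eval transform from a cfg entry's fields.
    # (0.5, 0.5, 0.5) mean/std follows the original JAX ViT preprocessing.
    from torchvision import transforms

    cfg = dict(input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9)

    size = cfg['input_size'][-1]
    eval_tf = transforms.Compose([
        transforms.Resize(int(size / cfg['crop_pct'])),   # 248 for 224 @ crop_pct 0.9
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(cfg['mean'], cfg['std']),
    ])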
@@ -97,29 +106,29 @@ default_cfgs = {
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
 
     # deit models (FB weights)
-    'vit_deit_tiny_patch16_224': _cfg(
+    'deit_tiny_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
-    'vit_deit_small_patch16_224': _cfg(
+    'deit_small_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
-    'vit_deit_base_patch16_224': _cfg(
+    'deit_base_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
-    'vit_deit_base_patch16_384': _cfg(
+    'deit_base_patch16_384': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
         input_size=(3, 384, 384), crop_pct=1.0),
-    'vit_deit_tiny_distilled_patch16_224': _cfg(
+    'deit_tiny_distilled_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
         classifier=('head', 'head_dist')),
-    'vit_deit_small_distilled_patch16_224': _cfg(
+    'deit_small_distilled_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
         classifier=('head', 'head_dist')),
-    'vit_deit_base_distilled_patch16_224': _cfg(
+    'deit_base_distilled_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
         classifier=('head', 'head_dist')),
-    'vit_deit_base_distilled_patch16_384': _cfg(
+    'deit_base_distilled_patch16_384': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
         input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
 
-    # ViT ImageNet-21K-P pretraining
+    # ViT ImageNet-21K-P pretraining by MIIL
     'vit_base_patch16_224_miil_in21k': _cfg(
         url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
         mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
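Since the DeiT registry keys drop their vit_ prefix, calling code needs the new names. A hedged caller-side example, assuming a timm build that contains this change:

    import timm

    # prior to this change:
    #   model = timm.create_model('vit_deit_base_patch16_224', pretrained=True)
    model = timm.create_model('deit_base_patch16_224', pretrained=True)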
@@ -133,11 +142,11 @@ default_cfgs = {
 class Attention(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim ** -0.5
+        self.scale = head_dim ** -0.5
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
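With qk_scale gone, attention always scales query-key dot products by head_dim ** -0.5. A standalone sketch of that computation, following the same shapes as the Attention module above (this is not the module itself):

    import torch
    import torch.nn as nn

    B, N, C, num_heads = 2, 197, 768, 12
    head_dim = C // num_heads          # 64
    scale = head_dim ** -0.5           # no qk_scale override any more

    qkv = nn.Linear(C, C * 3, bias=True)
    x = torch.randn(B, N, C)

    # (B, N, 3C) -> (3, B, heads, N, head_dim), then unpack q, k, v
    q, k, v = qkv(x).reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
    attn = (q @ k.transpose(-2, -1)) * scale              # (B, heads, N, N)
    out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)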
@@ -161,12 +170,11 @@ class Attention(nn.Module):
 class Block(nn.Module):
 
-    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
-        self.attn = Attention(
-            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
@@ -190,7 +198,7 @@ class VisionTransformer(nn.Module):
     """
 
     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
+                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                  act_layer=None, weight_init=''):
         """
@@ -204,7 +212,6 @@ class VisionTransformer(nn.Module):
             num_heads (int): number of attention heads
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
-            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
             distilled (bool): model includes a distillation token and head as in DeiT models
             drop_rate (float): dropout rate
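One knock-on effect worth flagging: qk_scale is no longer accepted anywhere in the stack, so code that passed it to the constructor will now raise a TypeError. A minimal before/after sketch:

    from timm.models.vision_transformer import VisionTransformer

    # before: VisionTransformer(embed_dim=192, depth=12, num_heads=3, qk_scale=None)
    # after: simply drop the argument
    model = VisionTransformer(embed_dim=192, depth=12, num_heads=3)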
@@ -233,8 +240,8 @@ class VisionTransformer(nn.Module):
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.Sequential(*[
             Block(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
-                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
+                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
             for i in range(depth)])
         self.norm = norm_layer(embed_dim)
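For reference, the stochastic depth decay rule above just spreads drop_path_rate linearly across the blocks; a quick check:

    import torch

    depth, drop_path_rate = 12, 0.1
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    print([round(r, 3) for r in dpr])
    # [0.0, 0.009, 0.018, ..., 0.091, 0.1] -- first block gets 0, last gets drop_path_rate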
@@ -254,16 +261,17 @@ class VisionTransformer(nn.Module):
         if distilled:
             self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
-        # Weight init
-        assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
-        head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
+        self.init_weights(weight_init)
+
+    def init_weights(self, mode=''):
+        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
+        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
         trunc_normal_(self.pos_embed, std=.02)
         if self.dist_token is not None:
             trunc_normal_(self.dist_token, std=.02)
-        if weight_init.startswith('jax'):
+        if mode.startswith('jax'):
             # leave cls token as zeros to match jax impl
-            for n, m in self.named_modules():
-                _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
+            named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
         else:
             trunc_normal_(self.cls_token, std=.02)
             self.apply(_init_vit_weights)
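The explicit named_modules() loop is replaced by named_apply from .helpers (imported at the top of this diff). A simplified sketch of what such a helper does, not the exact timm implementation, so the name-dependent branches in _init_vit_weights (head / pre_logits / mlp) still see qualified module names:

    import torch.nn as nn

    def named_apply_sketch(fn, module: nn.Module, name: str = '') -> nn.Module:
        # call fn(module=..., name=...) on this module, then recurse into children
        fn(module=module, name=name)
        for child_name, child in module.named_children():
            full_name = f'{name}.{child_name}' if name else child_name
            named_apply_sketch(fn, child, name=full_name)
        return module

    # e.g. named_apply_sketch(partial(_init_vit_weights, head_bias=0., jax_impl=True), model)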
@@ -272,6 +280,10 @@ class VisionTransformer(nn.Module):
         # this fn left here for compat with downstream users
         _init_vit_weights(m)
 
+    @torch.jit.ignore()
+    def load_pretrained(self, checkpoint_path, prefix=''):
+        _load_weights(self, checkpoint_path, prefix)
+
     @torch.jit.ignore
     def no_weight_decay(self):
         return {'pos_embed', 'cls_token', 'dist_token'}
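The new load_pretrained hook simply forwards to _load_weights below. A hypothetical usage example; the .npz path is a placeholder for one of the Google-released Flax checkpoints, not a file added by this change:

    import timm

    model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=1000)
    model.load_pretrained('ViT-B_16.npz')  # placeholder path to an official .npz checkpoint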
@@ -317,39 +329,92 @@ class VisionTransformer(nn.Module):
             return x
 
 
-def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
+def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
     """ ViT weight initialization
     * When called without n, head_bias, jax_impl args it will behave exactly the same
       as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
     * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
     """
-    if isinstance(m, nn.Linear):
-        if n.startswith('head'):
-            nn.init.zeros_(m.weight)
-            nn.init.constant_(m.bias, head_bias)
-        elif n.startswith('pre_logits'):
-            lecun_normal_(m.weight)
-            nn.init.zeros_(m.bias)
+    if isinstance(module, nn.Linear):
+        if name.startswith('head'):
+            nn.init.zeros_(module.weight)
+            nn.init.constant_(module.bias, head_bias)
+        elif name.startswith('pre_logits'):
+            lecun_normal_(module.weight)
+            nn.init.zeros_(module.bias)
         else:
             if jax_impl:
-                nn.init.xavier_uniform_(m.weight)
-                if m.bias is not None:
-                    if 'mlp' in n:
-                        nn.init.normal_(m.bias, std=1e-6)
+                nn.init.xavier_uniform_(module.weight)
+                if module.bias is not None:
+                    if 'mlp' in name:
+                        nn.init.normal_(module.bias, std=1e-6)
                     else:
-                        nn.init.zeros_(m.bias)
+                        nn.init.zeros_(module.bias)
             else:
-                trunc_normal_(m.weight, std=.02)
-                if m.bias is not None:
-                    nn.init.zeros_(m.bias)
-    elif jax_impl and isinstance(m, nn.Conv2d):
+                trunc_normal_(module.weight, std=.02)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+    elif jax_impl and isinstance(module, nn.Conv2d):
         # NOTE conv was left to pytorch default in my original init
-        lecun_normal_(m.weight)
-        if m.bias is not None:
-            nn.init.zeros_(m.bias)
-    elif isinstance(m, nn.LayerNorm):
-        nn.init.zeros_(m.bias)
-        nn.init.ones_(m.weight)
+        lecun_normal_(module.weight)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+        nn.init.zeros_(module.bias)
+        nn.init.ones_(module.weight)
+
+
+@torch.no_grad()
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
+    """ Load weights from .npz checkpoints for official Google Brain Flax implementation
+    """
+    import numpy as np
+
+    def _n2p(w, t=True):
+        if t and w.ndim == 4:
+            w = w.transpose([3, 2, 0, 1])
+        elif t and w.ndim == 3:
+            w = w.transpose([2, 0, 1])
+        elif t and w.ndim == 2:
+            w = w.transpose([1, 0])
+        return torch.from_numpy(w)
+
+    w = np.load(checkpoint_path)
+    if not prefix:
+        prefix = 'opt/target/' if 'opt/target/embedding/kernel' in w else prefix
+    input_conv_w = adapt_input_conv(
+        model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+    model.patch_embed.proj.weight.copy_(input_conv_w)
+    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
+    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+    model.pos_embed.copy_(_n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False))
+    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
+    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
+    if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+    for i, block in enumerate(model.blocks.children()):
+        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+        block.attn.qkv.weight.copy_(torch.cat([
+            _n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T,
+            _n2p(w[f'{mha_prefix}key/kernel'], t=False).flatten(1).T,
+            _n2p(w[f'{mha_prefix}value/kernel'], t=False).flatten(1).T]))
+        block.attn.qkv.bias.copy_(torch.cat([
+            _n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1),
+            _n2p(w[f'{mha_prefix}key/bias'], t=False).reshape(-1),
+            _n2p(w[f'{mha_prefix}value/bias'], t=False).reshape(-1)]))
+        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+        block.mlp.fc1.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/kernel']))
+        block.mlp.fc1.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/bias']))
+        block.mlp.fc2.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/kernel']))
+        block.mlp.fc2.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/bias']))
 
 
 def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
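The _n2p helper above moves Flax arrays into PyTorch layout. A small self-contained check of the two common cases, assuming the usual Flax conventions (conv kernels stored HWIO, dense kernels stored (in, out)):

    import numpy as np
    import torch

    conv_hwio = np.zeros((16, 16, 3, 768), dtype=np.float32)   # H, W, I, O
    dense_io = np.zeros((768, 3072), dtype=np.float32)         # in, out

    conv_oihw = torch.from_numpy(conv_hwio.transpose([3, 2, 0, 1]))
    dense_oi = torch.from_numpy(dense_io.transpose([1, 0]))

    assert conv_oihw.shape == (768, 3, 16, 16)   # matches nn.Conv2d.weight
    assert dense_oi.shape == (3072, 768)         # matches nn.Linear.weight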
@@ -417,23 +482,34 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw
     return model
 
 
 @register_model
+def vit_tiny_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Tiny (ViT-Ti/16)
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+    model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
 def vit_small_patch16_224(pretrained=False, **kwargs):
-    """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.
-    NOTE:
-        * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
-        * this model does not have a bias for QKV (unlike the official ViT and DeiT models)
+    """ ViT-Small (ViT-S/16)
+    NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
     """
-    model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
-        qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
-    if pretrained:
-        # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
-        model_kwargs.setdefault('qk_scale', 768 ** -0.5)
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
 
 
+@register_model
+def vit_small_patch32_224(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/32)
+    """
+    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_base_patch16_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
@@ -569,86 +645,86 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
 @register_model
-def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
+def deit_tiny_patch16_224(pretrained=False, **kwargs):
     """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
     model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
-    model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
 
 
 @register_model
-def vit_deit_small_patch16_224(pretrained=False, **kwargs):
+def deit_small_patch16_224(pretrained=False, **kwargs):
     """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
     model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
 
 
 @register_model
-def vit_deit_base_patch16_224(pretrained=False, **kwargs):
+def deit_base_patch16_224(pretrained=False, **kwargs):
     """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
 
 
 @register_model
-def vit_deit_base_patch16_384(pretrained=False, **kwargs):
+def deit_base_patch16_384(pretrained=False, **kwargs):
     """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
     return model
 
 
 @register_model
-def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
+def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
     """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
     model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer(
-        'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+        'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
     return model
 
 
 @register_model
-def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
+def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
     """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
     model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer(
-        'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+        'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
     return model
 
 
 @register_model
-def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
+def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
     """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer(
-        'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+        'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
     return model
 
 
 @register_model
-def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
+def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
     """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer(
-        'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
+        'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
     return model
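Finally, a short usage sketch for the distilled variants: their default_cfg lists two classifiers ('head', 'head_dist'), and in this version of the file the forward pass averages the two heads at eval time, so inference code stays unchanged:

    import torch
    import timm

    model = timm.create_model('deit_base_distilled_patch16_224', pretrained=False).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))   # (head + head_dist) / 2 at eval time
    print(logits.shape)  # torch.Size([1, 1000])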