@ -107,7 +107,8 @@ class Attention(nn.Module):
def forward ( self , x ) :
B , N , C = x . shape
q , k , v = self . qkv ( x ) . reshape ( B , N , 3 , self . num_heads , C / / self . num_heads ) . permute ( 2 , 0 , 3 , 1 , 4 )
qkv = self . qkv ( x ) . reshape ( B , N , 3 , self . num_heads , C / / self . num_heads ) . permute ( 2 , 0 , 3 , 1 , 4 )
q , k , v = qkv [ 0 ] , qkv [ 1 ] , qkv [ 2 ] # make torchscript happy (cannot use tensor as tuple)
attn = ( q @ k . transpose ( - 2 , - 1 ) ) * self . scale
attn = attn . softmax ( dim = - 1 )
@ -204,6 +205,9 @@ class VisionTransformer(nn.Module):
num_heads = 12 , mlp_ratio = 4. , qkv_bias = False , qk_scale = None , drop_rate = 0. , attn_drop_rate = 0. ,
drop_path_rate = 0. , hybrid_backbone = None , norm_layer = nn . LayerNorm ) :
super ( ) . __init__ ( )
self . num_classes = num_classes
self . embed_dim = embed_dim
if hybrid_backbone is not None :
self . patch_embed = HybridEmbed (
hybrid_backbone , img_size = img_size , in_chans = in_chans , embed_dim = embed_dim )
@ -229,7 +233,7 @@ class VisionTransformer(nn.Module):
#self.repr_act = nn.Tanh()
# Classifier head
self . head = nn . Linear ( embed_dim , num_classes )
self . head = nn . Linear ( embed_dim , num_classes ) if num_classes > 0 else nn . Identity ( )
trunc_normal_ ( self . pos_embed , std = .02 )
trunc_normal_ ( self . cls_token , std = .02 )
@ -244,11 +248,18 @@ class VisionTransformer(nn.Module):
nn . init . constant_ ( m . bias , 0 )
nn . init . constant_ ( m . weight , 1.0 )
@ property
@ torch.jit.ignore
def no_weight_decay ( self ) :
return { ' pos_embed ' , ' cls_token ' }
def forward ( self , x ) :
def get_classifier ( self ) :
return self . head
def reset_classifier ( self , num_classes , global_pool = ' ' ) :
self . num_classes = num_classes
self . head = nn . Linear ( self . embed_dim , num_classes ) if num_classes > 0 else nn . Identity ( )
def forward_features ( self , x ) :
B = x . shape [ 0 ]
x = self . patch_embed ( x )
@ -261,7 +272,11 @@ class VisionTransformer(nn.Module):
x = blk ( x )
x = self . norm ( x )
x = self . head ( x [ : , 0 ] )
return x [ : , 0 ]
def forward ( self , x ) :
x = self . forward_features ( x )
x = self . head ( x )
return x
@ -284,7 +299,7 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
model . default_cfg = default_cfgs [ ' vit_small_patch16_224 ' ]
if pretrained :
load_pretrained (
model , num_classes = kwargs. get ( ' num_classes ' , 0 ) , in_chans = kwargs . get ( ' in_chans ' , 3 ) , filter_fn = _conv_filter )
model , num_classes = model. num_classes , in_chans = kwargs . get ( ' in_chans ' , 3 ) , filter_fn = _conv_filter )
return model
@ -297,7 +312,7 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
model . default_cfg = default_cfgs [ ' vit_base_patch16_224 ' ]
if pretrained :
load_pretrained (
model , num_classes = kwargs. get ( ' num_classes ' , 0 ) , in_chans = kwargs . get ( ' in_chans ' , 3 ) , filter_fn = _conv_filter )
model , num_classes = model. num_classes , in_chans = kwargs . get ( ' in_chans ' , 3 ) , filter_fn = _conv_filter )
return model
@ -308,8 +323,7 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
norm_layer = partial ( nn . LayerNorm , eps = 1e-6 ) , * * kwargs )
model . default_cfg = default_cfgs [ ' vit_base_patch16_384 ' ]
if pretrained :
load_pretrained (
model , num_classes = kwargs . get ( ' num_classes ' , 0 ) , in_chans = kwargs . get ( ' in_chans ' , 3 ) )
load_pretrained ( model , num_classes = model . num_classes , in_chans = kwargs . get ( ' in_chans ' , 3 ) )
return model
@ -320,8 +334,7 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
norm_layer = partial ( nn . LayerNorm , eps = 1e-6 ) , * * kwargs )
model . default_cfg = default_cfgs [ ' vit_base_patch32_384 ' ]
if pretrained :
load_pretrained (
model , num_classes = kwargs . get ( ' num_classes ' , 0 ) , in_chans = kwargs . get ( ' in_chans ' , 3 ) )
load_pretrained ( model , num_classes = model . num_classes , in_chans = kwargs . get ( ' in_chans ' , 3 ) )
return model
@ -339,8 +352,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
norm_layer = partial ( nn . LayerNorm , eps = 1e-6 ) , * * kwargs )
model . default_cfg = default_cfgs [ ' vit_large_patch16_384 ' ]
if pretrained :
load_pretrained (
model , num_classes = kwargs . get ( ' num_classes ' , 0 ) , in_chans = kwargs . get ( ' in_chans ' , 3 ) )
load_pretrained ( model , num_classes = model . num_classes , in_chans = kwargs . get ( ' in_chans ' , 3 ) )
return model
@ -351,8 +363,7 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
norm_layer = partial ( nn . LayerNorm , eps = 1e-6 ) , * * kwargs )
model . default_cfg = default_cfgs [ ' vit_large_patch32_384 ' ]
if pretrained :
load_pretrained (
model , num_classes = kwargs . get ( ' num_classes ' , 0 ) , in_chans = kwargs . get ( ' in_chans ' , 3 ) )
load_pretrained ( model , num_classes = model . num_classes , in_chans = kwargs . get ( ' in_chans ' , 3 ) )
return model