@@ -40,7 +40,7 @@ from .vision_transformer import Mlp, Block
 def _cfg(url='', **kwargs):
     return {
         'url': url,
-        'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None,
+        'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
         'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
         'classifier': ('head.0', 'head.1'),
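Aside: `crop_pct` here follows the usual timm eval convention, where the image is resized to `input_size / crop_pct` before a center crop. A minimal sketch of that arithmetic, with a hypothetical helper name:

import math

def eval_resize_size(input_size: int, crop_pct: float) -> int:
    # hypothetical helper mirroring timm's test-time convention:
    # resize so a center crop of `input_size` covers `crop_pct` of the resized image
    return int(math.floor(input_size / crop_pct))

print(eval_resize_size(240, 0.875))  # 274: resize short side to 274, then center-crop 240
print(eval_resize_size(408, 1.0))    # 408: resize straight to 408, no crop margin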
@@ -56,7 +56,7 @@ default_cfgs = {
     ),
     'crossvit_15_dagger_408': _cfg(
         url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth',
-        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
+        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
     ),
     'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'),
     'crossvit_18_dagger_240': _cfg(
@@ -65,7 +65,7 @@ default_cfgs = {
     ),
     'crossvit_18_dagger_408': _cfg(
         url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth',
-        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
+        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
     ),
     'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'),
     'crossvit_9_dagger_240': _cfg(
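With `crop_pct` now in the per-model cfgs, the standard timm data helpers pick it up automatically; a quick usage sketch, assuming a timm build that includes these cfgs:

import timm
from timm.data import resolve_data_config, create_transform

model = timm.create_model('crossvit_18_dagger_408', pretrained=False)
config = resolve_data_config({}, model=model)    # reads input_size / crop_pct / mean / std from default_cfgs
transform = create_transform(**config)           # eval pipeline resizes per crop_pct before cropping
print(config['input_size'], config['crop_pct'])  # (3, 408, 408) 1.0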
@@ -263,7 +263,7 @@ class CrossViT(nn.Module):
             self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
             embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
             qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
-            norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False
+            norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False, crop_scale=False,
     ):
         super().__init__()
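The new `crop_scale` flag is opt-in at construction time. A hedged example reusing the defaults visible in this signature; the exact branch settings of the pretrained configs live in the factory functions further down the file:

# sketch: secondary branch runs at 224/240 scale and, with crop_scale=True,
# is produced by center-cropping the 240px input instead of bicubic resizing
model = CrossViT(img_size=240, img_scale=(1.0, 224 / 240), patch_size=(12, 16), crop_scale=True)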
@@ -271,6 +271,7 @@ class CrossViT(nn.Module):
         self.img_size = to_2tuple(img_size)
         img_scale = to_2tuple(img_scale)
         self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
+        self.crop_scale = crop_scale  # crop instead of interpolate for scale
         num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
         self.num_branches = len(patch_size)
         self.embed_dim = embed_dim
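To make the scaled-size bookkeeping concrete, here is the same comprehension evaluated standalone; the grid arithmetic below is assumed to match what `_compute_num_patches` does, with values for a 240px model and a 224/240 secondary scale:

img_size = (240, 240)
img_scale = (1.0, 224 / 240)
patch_size = (12, 16)

img_size_scaled = [tuple(int(sj * si) for sj in img_size) for si in img_scale]
num_patches = [(s[0] // p) * (s[1] // p) for s, p in zip(img_size_scaled, patch_size)]
print(img_size_scaled)  # [(240, 240), (224, 224)]
print(num_patches)      # [400, 196], one token grid per branch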
@@ -307,8 +308,7 @@ class CrossViT(nn.Module):
             for i in range(self.num_branches)])

         for i in range(self.num_branches):
-            if hasattr(self, f'pos_embed_{i}'):
-                trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
+            trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
             trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)

         self.apply(self._init_weights)
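The `pos_embed_{i}` / `cls_token_{i}` attributes this loop initializes are registered per branch under fixed names rather than in a Python list, which keeps them resolvable under torchscript. A minimal sketch of that pattern, with made-up shapes:

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_

class TwoBranchTokens(nn.Module):
    # illustrative only: per-branch parameters under f-string names, as crossvit does
    def __init__(self, embed_dim=(192, 384), num_patches=(400, 196)):
        super().__init__()
        self.num_branches = len(embed_dim)
        for i, (d, n) in enumerate(zip(embed_dim, num_patches)):
            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + n, d)))
            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, d)))
        for i in range(self.num_branches):
            trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
            trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)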
@@ -324,9 +324,12 @@ class CrossViT(nn.Module):

     @torch.jit.ignore
     def no_weight_decay(self):
-        out = {'cls_token'}
-        if self.pos_embed[0].requires_grad:
-            out.add('pos_embed')
+        out = set()
+        for i in range(self.num_branches):
+            out.add(f'cls_token_{i}')
+            pe = getattr(self, f'pos_embed_{i}', None)
+            if pe is not None and pe.requires_grad:
+                out.add(f'pos_embed_{i}')
         return out

     def get_classifier(self):
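The rewrite makes `no_weight_decay` return the real per-branch parameter names, so an optimizer setup that filters on it keeps every cls token and position embedding out of weight decay. A sketch of the typical consumption; the grouping helper is the caller's, not this file's:

import torch

def param_groups(model, weight_decay=0.05):
    skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set()
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name in skip or p.ndim <= 1:  # also skip decay for biases / norm weights
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.},
    ]

# optimizer = torch.optim.AdamW(param_groups(model), lr=1e-3)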
@@ -342,23 +345,29 @@ class CrossViT(nn.Module):
         B, C, H, W = x.shape
         xs = []
         for i, patch_embed in enumerate(self.patch_embed):
+            x_ = x
             ss = self.img_size_scaled[i]
-            x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False) if H != ss[0] else x
-            tmp = patch_embed(x_)
+            if H != ss[0] or W != ss[1]:
+                if self.crop_scale and ss[0] <= H and ss[1] <= W:
+                    cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
+                    x_ = x_[:, :, cu:cu + ss[0], cl:cl + ss[1]]
+                else:
+                    x_ = torch.nn.functional.interpolate(x_, size=ss, mode='bicubic', align_corners=False)
+            x_ = patch_embed(x_)
             cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
             cls_tokens = cls_tokens.expand(B, -1, -1)
-            tmp = torch.cat((cls_tokens, tmp), dim=1)
+            x_ = torch.cat((cls_tokens, x_), dim=1)
             pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1  # hard-coded for torch jit script
-            tmp = tmp + pos_embed
-            tmp = self.pos_drop(tmp)
-            xs.append(tmp)
+            x_ = x_ + pos_embed
+            x_ = self.pos_drop(x_)
+            xs.append(x_)

         for i, blk in enumerate(self.blocks):
             xs = blk(xs)

         # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
         xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
-        return [x[:, 0] for x in xs]
+        return [xo[:, 0] for xo in xs]

     def forward(self, x):
         xs = self.forward_features(x)
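The crop-vs-interpolate branch scaling above is easy to check in isolation; a standalone sketch, function name hypothetical, mirroring the new forward_features logic:

import torch
import torch.nn.functional as F

def scale_image(x, ss, crop_scale=False):
    # mirror of the per-branch scaling in forward_features: center-crop when
    # enabled and the target fits, otherwise fall back to bicubic interpolation
    H, W = x.shape[-2:]
    if H != ss[0] or W != ss[1]:
        if crop_scale and ss[0] <= H and ss[1] <= W:
            cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
            x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
        else:
            x = F.interpolate(x, size=ss, mode='bicubic', align_corners=False)
    return x

x = torch.randn(2, 3, 240, 240)
print(scale_image(x, (224, 224), crop_scale=True).shape)   # [2, 3, 224, 224] via center crop
print(scale_image(x, (224, 224), crop_scale=False).shape)  # [2, 3, 224, 224] via bicubic resize
print(scale_image(x, (408, 408), crop_scale=True).shape)   # upscale: crop can't apply, interpolates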