@@ -12,6 +12,7 @@ This implementation is experimental and subject to change in manners that will b
  GitHub link above. It needs further investigation as the throughput vs memory tradeoff doesn't appear beneficial.
* num_heads per stage is not detailed for Huge and Giant model variants
* 'Giant' is 3B params in the paper but ~2.6B here despite matching paper dim + block counts
* experiments are ongoing wrt 'main branch' norm layer use and weight init scheme
Noteworthy additions over official Swin v1:
* MLP relative position embedding is looking promising and adapts to different image/window sizes
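A minimal sketch of that MLP ("meta") relative position bias idea, assuming a recent PyTorch: pair-wise relative coordinates are recomputed for whatever window size is in use, log-scaled, and mapped to per-head biases by a small MLP, so no fixed-size bias table ties the weights to one window size. The names and the 384 hidden width below are illustrative (the module in this file calls it meta_hidden_dim), not the exact implementation.

import torch
import torch.nn as nn

def log_relative_coords(window_size):
    # pair-wise relative coordinates for every token pair in a window,
    # compressed with a sign-preserving log so large offsets stay bounded
    coords = torch.stack(torch.meshgrid(
        torch.arange(window_size[0]), torch.arange(window_size[1]), indexing='ij'))
    coords = coords.flatten(1)                      # 2, Wh * Ww
    rel = coords[:, :, None] - coords[:, None, :]   # 2, N, N
    rel = rel.permute(1, 2, 0).float()              # N, N, 2
    return torch.sign(rel) * torch.log1p(rel.abs())

num_heads, window_size = 3, (7, 7)
meta_mlp = nn.Sequential(nn.Linear(2, 384), nn.ReLU(), nn.Linear(384, num_heads))
bias = meta_mlp(log_relative_coords(window_size))   # N, N, num_heads
bias = bias.permute(2, 0, 1)                        # num_heads, N, N, added to attention logits

Changing window_size simply regenerates the coordinate grid, which is what lets a resolution change be handled without a new bias table.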
@@ -37,7 +38,7 @@ import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
from .helpers import build_model_with_cfg, named_apply
from .layers import DropPath, Mlp, to_2tuple, _assert
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn
@@ -67,27 +68,29 @@ default_cfgs = {
    'swin_v2_cr_tiny_384': _cfg(
        url="", input_size=(3, 384, 384), crop_pct=1.0),
    'swin_v2_cr_tiny_224': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=1.0),
        url="", input_size=(3, 224, 224), crop_pct=0.9),
    'swin_v2_cr_tiny_ns_224': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=0.9),
    'swin_v2_cr_small_384': _cfg(
        url="", input_size=(3, 384, 384), crop_pct=1.0),
    'swin_v2_cr_small_224': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=1.0),
        url="", input_size=(3, 224, 224), crop_pct=0.9),
    'swin_v2_cr_base_384': _cfg(
        url="", input_size=(3, 384, 384), crop_pct=1.0),
    'swin_v2_cr_base_224': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=1.0),
        url="", input_size=(3, 224, 224), crop_pct=0.9),
    'swin_v2_cr_large_384': _cfg(
        url="", input_size=(3, 384, 384), crop_pct=1.0),
    'swin_v2_cr_large_224': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=1.0),
        url="", input_size=(3, 224, 224), crop_pct=0.9),
    'swin_v2_cr_huge_384': _cfg(
        url="", input_size=(3, 384, 384), crop_pct=1.0),
    'swin_v2_cr_huge_224': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=1.0),
        url="", input_size=(3, 224, 224), crop_pct=0.9),
    'swin_v2_cr_giant_384': _cfg(
        url="", input_size=(3, 384, 384), crop_pct=1.0),
    'swin_v2_cr_giant_224': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=1.0),
        url="", input_size=(3, 224, 224), crop_pct=0.9),
}
@@ -175,7 +178,7 @@ class WindowMultiHeadAttention(nn.Module):
            hidden_features=meta_hidden_dim,
            out_features=num_heads,
            act_layer=nn.ReLU,
            drop=0.  # FIXME should we add stochasticity?
            drop=0.1  # FIXME should there be stochasticity, appears to 'overfit' without?
        )
        self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
        self._make_pair_wise_relative_positions()
@@ -336,7 +339,8 @@ class SwinTransformerBlock(nn.Module):
        self.norm2 = norm_layer(dim)
        self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()
        # extra norm layer mentioned for Huge/Giant models in V2 paper (FIXME may be in wrong spot?)
        # Extra main branch norm layer mentioned for Huge/Giant models in V2 paper.
        # Also being used as final network norm and optional stage ending norm while still in a C-last format.
        self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()
        self._make_attention_mask()
@@ -392,13 +396,16 @@ class SwinTransformerBlock(nn.Module):
        x = x.view(B, H, W, C)
        # cyclic shift
        if any(self.shift_size):
            shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
        else:
            shifted_x = x
        sh, sw = self.shift_size
        do_shift: bool = any(self.shift_size)
        if do_shift:
            # FIXME PyTorch XLA needs cat impl, roll not lowered
            # x = torch.cat([x[:, sh:], x[:, :sh]], dim=1)
            # x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2)
            x = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2))
        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # num_windows * B, window_size, window_size, C
        x_windows = window_partition(x, self.window_size)  # num_windows * B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
        # W-MSA/SW-MSA
@@ -406,13 +413,14 @@ class SwinTransformerBlock(nn.Module):
        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
        shifted_x = window_reverse(attn_windows, self.window_size, self.feat_size)  # B H' W' C
        x = window_reverse(attn_windows, self.window_size, self.feat_size)  # B H' W' C
        # reverse cyclic shift
        if any(self.shift_size):
            x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
        else:
            x = shifted_x
        if do_shift:
            # FIXME PyTorch XLA needs cat impl, roll not lowered
            # x = torch.cat([x[:, -sh:], x[:, :-sh]], dim=1)
            # x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
            x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))
        x = x.view(B, L, C)
        return x
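The commented-out torch.cat lines above preserve an XLA-friendly alternative: torch.roll is not lowered on PyTorch/XLA, but the same cyclic shift can be built from slicing and concatenation. A standalone sketch of the equivalence, assuming an NHWC tensor as used in this block:

import torch

x = torch.randn(2, 8, 8, 4)   # B, H, W, C
sh, sw = 3, 3                 # shift sizes, typically window_size // 2

rolled = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2))

# same shift expressed with cat, which XLA backends can lower
via_cat = torch.cat([x[:, sh:], x[:, :sh]], dim=1)
via_cat = torch.cat([via_cat[:, :, sw:], via_cat[:, :, :sw]], dim=2)

assert torch.equal(rolled, via_cat)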
@@ -429,7 +437,7 @@ class SwinTransformerBlock(nn.Module):
        # NOTE post-norm branches (op -> norm -> drop)
        x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
        x = x + self.drop_path2(self.norm2(self.mlp(x)))
        x = self.norm3(x)  # main-branch norm enabled for some blocks (every 6 for Huge/Giant)
        x = self.norm3(x)  # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
        return x
@@ -452,8 +460,10 @@ class PatchMerging(nn.Module):
        Returns:
            output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
        """
        x = bchw_to_bhwc(x).unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2)
        x = x.permute(0, 1, 2, 5, 4, 3).flatten(3)  # permute maintains compat with ch order in official swin impl
        B, C, H, W = x.shape
        # unfold + BCHW -> BHWC together
        # ordering, 5, 3, 1 instead of 3, 5, 1 maintains compat with original swin v1 merge
        x = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)
        x = self.norm(x)
        x = bhwc_to_bchw(self.reduction(x))
        return x
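To illustrate the ordering comment: the single reshape + permute lays out the 2x2 spatial neighbours and channels in the same order as Swin V1's strided-slice-and-concat merge, which is what keeps V1 merge weights loadable. A standalone sanity check, assuming BCHW input as in this forward:

import torch

B, C, H, W = 2, 3, 8, 8
x = torch.randn(B, C, H, W)

# reshape/permute path, as in the new PatchMerging.forward
merged = x.reshape(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 5, 3, 1).flatten(3)

# Swin V1 style merge: gather the four 2x2 neighbours in BHWC, concat on channels
xb = x.permute(0, 2, 3, 1)  # B, H, W, C
v1 = torch.cat([xb[:, 0::2, 0::2], xb[:, 1::2, 0::2], xb[:, 0::2, 1::2], xb[:, 1::2, 1::2]], dim=-1)

assert torch.equal(merged, v1)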
@@ -497,8 +507,8 @@ class SwinTransformerStage(nn.Module):
        drop_attn (float): Dropout rate of attention map
        drop_path (float): Dropout in main path
        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
        grad_checkpointing (bool): If true checkpointing is utilized
        extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks
        extra_norm_stage (bool): End each stage with an extra norm layer in main branch
        sequential_attn (bool): If true sequential self-attention is performed
    """
@@ -515,17 +525,23 @@ class SwinTransformerStage(nn.Module):
            drop_attn: float = 0.0,
            drop_path: Union[List[float], float] = 0.0,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            grad_checkpointing: bool = False,
            extra_norm_period: int = 0,
            extra_norm_stage: bool = False,
            sequential_attn: bool = False,
    ) -> None:
        super(SwinTransformerStage, self).__init__()
        self.downscale: bool = downscale
        self.grad_checkpointing: bool = grad_checkpointing
        self.grad_checkpointing: bool = False
        self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size
        self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer) if downscale else nn.Identity()
        def _extra_norm(index):
            i = index + 1
            if extra_norm_period and i % extra_norm_period == 0:
                return True
            return i == depth if extra_norm_stage else False
        embed_dim = embed_dim * 2 if downscale else embed_dim
        self.blocks = nn.Sequential(*[
            SwinTransformerBlock(
@@ -538,7 +554,7 @@ class SwinTransformerStage(nn.Module):
                drop=drop,
                drop_attn=drop_attn,
                drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
                extra_norm=not (index + 1) % extra_norm_period if extra_norm_period else False,
                extra_norm=_extra_norm(index),
                sequential_attn=sequential_attn,
                norm_layer=norm_layer,
            )
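For reference, the _extra_norm rule restated standalone (illustrative only, with depth and flags passed explicitly): with extra_norm_period=6, a 12-block Huge/Giant stage norms blocks 5 and 11 (0-indexed), and extra_norm_stage=True additionally norms the last block of a stage regardless of period.

def _extra_norm(index, depth, extra_norm_period=0, extra_norm_stage=False):
    # mirrors the closure defined in SwinTransformerStage.__init__ above
    i = index + 1
    if extra_norm_period and i % extra_norm_period == 0:
        return True
    return i == depth if extra_norm_stage else False

print([i for i in range(12) if _extra_norm(i, 12, extra_norm_period=6)])   # [5, 11]
print([i for i in range(6) if _extra_norm(i, 6, extra_norm_stage=True)])   # [5]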
@@ -600,9 +616,9 @@ class SwinTransformerV2Cr(nn.Module):
        attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
        drop_path_rate (float): Stochastic depth rate. Default: 0.0
        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
        grad_checkpointing (bool): If true checkpointing is utilized. Default: False
        extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks in stage
        extra_norm_stage (bool): End each stage with an extra norm layer in main branch
        sequential_attn (bool): If true sequential self-attention is performed. Default: False
        use_deformable (bool): If true deformable block is used. Default: False
    """
    def __init__(
@@ -621,10 +637,11 @@ class SwinTransformerV2Cr(nn.Module):
            attn_drop_rate: float = 0.0,
            drop_path_rate: float = 0.0,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            grad_checkpointing: bool = False,
            extra_norm_period: int = 0,
            extra_norm_stage: bool = False,
            sequential_attn: bool = False,
            global_pool: str = 'avg',
            weight_init='skip',
            **kwargs: Any
    ) -> None:
        super(SwinTransformerV2Cr, self).__init__()
@@ -638,7 +655,7 @@ class SwinTransformerV2Cr(nn.Module):
        self.window_size: int = window_size
        self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
        self.patch_embed: nn.Module = PatchEmbed(
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            embed_dim=embed_dim, norm_layer=norm_layer)
        patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
@@ -659,8 +676,8 @@ class SwinTransformerV2Cr(nn.Module):
                drop=drop_rate,
                drop_attn=attn_drop_rate,
                drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
                grad_checkpointing=grad_checkpointing,
                extra_norm_period=extra_norm_period,
                extra_norm_stage=extra_norm_stage or (index + 1) == len(depths),  # last stage ends w/ norm
                sequential_attn=sequential_attn,
                norm_layer=norm_layer,
            )
@@ -668,12 +685,12 @@ class SwinTransformerV2Cr(nn.Module):
        self.stages = nn.Sequential(*stages)
        self.global_pool: str = global_pool
        self.head: nn.Module = nn.Linear(
            in_features=self.num_features, out_features=num_classes) if num_classes else nn.Identity()
        self.head = nn.Linear(self.num_features, num_classes) if num_classes else nn.Identity()
        # FIXME weight init TBD, PyTorch default init appears to be working well,
        # but differs from usual ViT or Swin init.
        # named_apply(init_weights, self)
        # current weight init skips custom init and uses pytorch layer defaults, seems to work well
        # FIXME more experiments needed
        if weight_init != 'skip':
            named_apply(init_weights, self)
    def update_input_size(
            self,
@@ -704,13 +721,28 @@ class SwinTransformerV2Cr(nn.Module):
                new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
            )
    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^patch_embed',  # stem and embed
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+).downsample', (0,)),
                (r'^stages\.(\d+)\.\w+\.(\d+)', None),
            ]
        )
    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable
    @torch.jit.ignore()
    def get_classifier(self) -> nn.Module:
        """ Method returns the classification head of the model.
        Returns:
            head (nn.Module): Current classification head
        """
        head: nn.Module = self.head
        return head
        return self.head
    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
""" Method results the classification head
@@ -722,8 +754,7 @@ class SwinTransformerV2Cr(nn.Module):
        self.num_classes: int = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head: nn.Module = nn.Linear(
            in_features=self.num_features, out_features=num_classes) if num_classes > 0 else nn.Identity()
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
@@ -742,41 +773,28 @@ class SwinTransformerV2Cr(nn.Module):
def init_weights(module: nn.Module, name: str = ''):
    # FIXME WIP
    # FIXME WIP determining if there's a better weight init
    if isinstance(module, nn.Linear):
        if 'qkv' in name:
            # treat the weights of Q, K, V separately
            val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
            nn.init.uniform_(module.weight, -val, val)
        elif 'head' in name:
            nn.init.zeros_(module.weight)
        else:
            nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs):
    if default_cfg is None:
        default_cfg = deepcopy(default_cfgs[variant])
    overlay_external_default_cfg(default_cfg, kwargs)
    default_num_classes = default_cfg['num_classes']
    default_img_size = default_cfg['input_size'][-2:]
    num_classes = kwargs.pop('num_classes', default_num_classes)
    img_size = kwargs.pop('img_size', default_img_size)
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')
    model = build_model_with_cfg(
        SwinTransformerV2Cr,
        variant,
        pretrained,
        default_cfg=default_cfg,
        img_size=img_size,
        num_classes=num_classes,
        SwinTransformerV2Cr, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs
    )
    return model
@@ -804,6 +822,21 @@ def swin_v2_cr_tiny_224(pretrained=False, **kwargs):
    return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_tiny_ns_224(pretrained=False, **kwargs):
    """Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms.
    **Experimental, may make default if results are improved.**
    """
    model_kwargs = dict(
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
        extra_norm_stage=True,
        **kwargs
    )
    return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_ns_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_small_384(pretrained=False, **kwargs):
    """Swin-S V2 CR @ 384x384, trained ImageNet-1k"""