@@ -34,6 +34,7 @@ from typing import Tuple, Optional, List, Union, Any, Type
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -41,7 +42,7 @@ from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
from .layers import DropPath, Mlp, to_2tuple, _assert
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn

_logger = logging.getLogger(__name__)
@@ -186,12 +187,13 @@ class WindowMultiHeadAttention(nn.Module):
            act_layer=nn.ReLU,
            drop=(0.125, 0.)  # FIXME should there be stochasticity, appears to 'overfit' without?
        )
        self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
        # NOTE old checkpoints used inverse of logit_scale ('tau') following the paper, see conversion fn
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads)))
        self._make_pair_wise_relative_positions()

    def _make_pair_wise_relative_positions(self) -> None:
        """Method initializes the pair-wise relative positions to compute the positional biases."""
        device = self.tau.device
        device = self.logit_scale.device
        coordinates = torch.stack(torch.meshgrid([
            torch.arange(self.window_size[0], device=device),
            torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1)
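For context on this hunk (not part of the patch): the old `tau` divided the cosine-similarity logits per head, while the new `logit_scale` multiplies them by its exponential, so the two parameterizations coincide when `logit_scale = log(1 / tau)`, which is exactly the mapping the conversion function later in this diff applies. A minimal standalone sketch of that equivalence, with made-up values:

```python
import torch

# hypothetical per-head temperatures as an old 'tau' checkpoint would store them
tau = torch.tensor([0.05, 0.1, 1.0])
# equivalent new parameter
logit_scale = torch.log(1. / tau)

logits = torch.randn(3, 4, 4)  # (num_heads, tokens, tokens) cosine similarities

old_attn = logits / tau.reshape(-1, 1, 1)                # old: divide by tau
new_attn = logits * logit_scale.exp().reshape(-1, 1, 1)  # new: multiply by exp(logit_scale)
assert torch.allclose(old_attn, new_attn)                # identical, ignoring the clamps
```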
@@ -250,10 +252,11 @@ class WindowMultiHeadAttention(nn.Module):
        query, key, value = qkv.unbind(0)
        # compute attention map with scaled cosine attention
        denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1)
        attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6)
        attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1)
        attn = (F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1))
        logit_scale = torch.clamp(self.logit_scale.reshape(1, self.num_heads, 1, 1), max=math.log(1. / 0.01)).exp()
        attn = attn * logit_scale
        attn = attn + self._relative_positional_encodings()
        if mask is not None:
            # Apply mask if utilized
            num_win: int = mask.shape[0]
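Pulled out of the module, the rewritten attention logits amount to the sketch below (tensor shapes are made up; the relative position bias and window masking handled above are omitted). The clamp caps the learned per-head scale at `1 / 0.01 = 100`, while the `log(10)` initialisation from the earlier hunk starts it at 10; numerically this matches the old `torch.norm`-based denominator up to the small epsilon clamps.

```python
import math
import torch
import torch.nn.functional as F

B, num_heads, N, head_dim = 2, 3, 16, 32             # illustrative sizes only
query = torch.randn(B, num_heads, N, head_dim)
key = torch.randn(B, num_heads, N, head_dim)
logit_scale = torch.log(10 * torch.ones(num_heads))  # a learnable nn.Parameter in the model

# cosine similarity between queries and keys, per head
attn = F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1)
# learned per-head scale, clamped so exp() never exceeds 1 / 0.01 = 100
scale = torch.clamp(logit_scale.reshape(1, num_heads, 1, 1), max=math.log(1. / 0.01)).exp()
attn = (attn * scale).softmax(dim=-1)
```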
@@ -309,7 +312,7 @@ class SwinTransformerBlock(nn.Module):
            window_size: Tuple[int, int],
            shift_size: Tuple[int, int] = (0, 0),
            mlp_ratio: float = 4.0,
            init_values: float = 0,
            init_values: Optional[float] = 0,
            drop: float = 0.0,
            drop_attn: float = 0.0,
            drop_path: float = 0.0,
@@ -323,7 +326,7 @@ class SwinTransformerBlock(nn.Module):
        self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
        self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
        self.window_area = self.window_size[0] * self.window_size[1]
        self.init_values: float = init_values
        self.init_values: Optional[float] = init_values
        # attn branch
        self.attn = WindowMultiHeadAttention(
@@ -387,7 +390,7 @@ class SwinTransformerBlock(nn.Module):
    def init_weights(self):
        # extra, module specific weight init
        if self.init_values:
        if self.init_values is not None:
            nn.init.constant_(self.norm1.weight, self.init_values)
            nn.init.constant_(self.norm2.weight, self.init_values)
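The switch from a truthiness test to an explicit `is not None` changes behaviour only when layer-scale-style init is requested with a value of exactly zero, which the old check silently skipped. A quick illustration (values hypothetical):

```python
for init_values in (None, 0.0, 1e-5):
    old_check = bool(init_values)        # old: `if self.init_values:` skips 0.0
    new_check = init_values is not None  # new: only None disables the norm-weight init
    print(init_values, old_check, new_check)
# prints: None False False / 0.0 False True / 1e-05 True True
```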
@@ -536,7 +539,7 @@ class SwinTransformerStage(nn.Module):
            feat_size: Tuple[int, int],
            window_size: Tuple[int, int],
            mlp_ratio: float = 4.0,
            init_values: float = 0.0,
            init_values: Optional[float] = 0.0,
            drop: float = 0.0,
            drop_attn: float = 0.0,
            drop_path: Union[List[float], float] = 0.0,
@@ -650,7 +653,7 @@ class SwinTransformerV2Cr(nn.Module):
            depths: Tuple[int, ...] = (2, 2, 6, 2),
            num_heads: Tuple[int, ...] = (3, 6, 12, 24),
            mlp_ratio: float = 4.0,
            init_values: float = 0.0,
            init_values: Optional[float] = 0.,
            drop_rate: float = 0.0,
            attn_drop_rate: float = 0.0,
            drop_path_rate: float = 0.0,
@@ -808,6 +811,21 @@ def init_weights(module: nn.Module, name: str = ''):
        module.init_weights()

def checkpoint_filter_fn(state_dict, model):
    """ convert patch embedding weight from manual patchify + linear proj to conv """
    out_dict = {}
    if 'model' in state_dict:
        # For deit models
        state_dict = state_dict['model']
    for k, v in state_dict.items():
        if 'tau' in k:
            # convert old tau based checkpoints -> logit_scale (inverse)
            v = torch.log(1 / v)
            k = k.replace('tau', 'logit_scale')
        out_dict[k] = v
    return out_dict

def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')
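Usage-wise, the filter is presumably wired in through `build_model_with_cfg`'s pretrained-weight loading (not shown in this excerpt), but it can also be applied by hand when resuming from an old `tau`-based checkpoint. A sketch, with a hypothetical file path and assuming the remaining keys already match the refactored module:

```python
import torch

model = swinv2_cr_small_ns_224(pretrained=False)
state_dict = torch.load('old_tau_checkpoint.pth', map_location='cpu')  # hypothetical path
model.load_state_dict(checkpoint_filter_fn(state_dict, model))         # 'tau' keys become 'logit_scale'
```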
@@ -890,7 +908,6 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs):
        embed_dim=96,
        depths=(2, 2, 18, 2),
        num_heads=(3, 6, 12, 24),
        init_values=1e-5,
        extra_norm_stage=True,
        **kwargs
    )
@@ -928,7 +945,6 @@ def swinv2_cr_base_ns_224(pretrained=False, **kwargs):
        embed_dim=128,
        depths=(2, 2, 18, 2),
        num_heads=(4, 8, 16, 32),
        init_values=1e-6,
        extra_norm_stage=True,
        **kwargs
    )