|
|
|
@ -76,10 +76,15 @@ default_cfgs = {
|
|
|
|
|
'swin_v2_cr_small_224': _cfg(
|
|
|
|
|
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
|
|
|
|
|
input_size=(3, 224, 224), crop_pct=0.9),
|
|
|
|
|
'swin_v2_cr_small_ns_224': _cfg(
|
|
|
|
|
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
|
|
|
|
|
input_size=(3, 224, 224), crop_pct=0.9),
|
|
|
|
|
'swin_v2_cr_base_384': _cfg(
|
|
|
|
|
url="", input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
'swin_v2_cr_base_224': _cfg(
|
|
|
|
|
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
|
|
|
|
'swin_v2_cr_base_ns_224': _cfg(
|
|
|
|
|
url="", input_size=(3, 224, 224), crop_pct=0.9),
|
|
|
|
|
'swin_v2_cr_large_384': _cfg(
|
|
|
|
|
url="", input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
'swin_v2_cr_large_224': _cfg(
|
|
|
|
@ -179,7 +184,7 @@ class WindowMultiHeadAttention(nn.Module):
|
|
|
|
|
hidden_features=meta_hidden_dim,
|
|
|
|
|
out_features=num_heads,
|
|
|
|
|
act_layer=nn.ReLU,
|
|
|
|
|
drop=0.1 # FIXME should there be stochasticity, appears to 'overfit' without?
|
|
|
|
|
drop=(0.125, 0.) # FIXME should there be stochasticity, appears to 'overfit' without?
|
|
|
|
|
)
|
|
|
|
|
self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
|
|
|
|
|
self._make_pair_wise_relative_positions()
|
|
|
|
@ -304,6 +309,7 @@ class SwinTransformerBlock(nn.Module):
|
|
|
|
|
window_size: Tuple[int, int],
|
|
|
|
|
shift_size: Tuple[int, int] = (0, 0),
|
|
|
|
|
mlp_ratio: float = 4.0,
|
|
|
|
|
init_values: float = 0,
|
|
|
|
|
drop: float = 0.0,
|
|
|
|
|
drop_attn: float = 0.0,
|
|
|
|
|
drop_path: float = 0.0,
|
|
|
|
@ -317,6 +323,7 @@ class SwinTransformerBlock(nn.Module):
|
|
|
|
|
self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
|
|
|
|
|
self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
|
|
|
|
|
self.window_area = self.window_size[0] * self.window_size[1]
|
|
|
|
|
self.init_values: float = init_values
|
|
|
|
|
|
|
|
|
|
# attn branch
|
|
|
|
|
self.attn = WindowMultiHeadAttention(
|
|
|
|
@ -345,6 +352,7 @@ class SwinTransformerBlock(nn.Module):
|
|
|
|
|
self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()
|
|
|
|
|
|
|
|
|
|
self._make_attention_mask()
|
|
|
|
|
self.init_weights()
|
|
|
|
|
|
|
|
|
|
def _calc_window_shift(self, target_window_size):
|
|
|
|
|
window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)]
|
|
|
|
@ -377,6 +385,12 @@ class SwinTransformerBlock(nn.Module):
|
|
|
|
|
attn_mask = None
|
|
|
|
|
self.register_buffer("attn_mask", attn_mask, persistent=False)
|
|
|
|
|
|
|
|
|
|
def init_weights(self):
|
|
|
|
|
# extra, module specific weight init
|
|
|
|
|
if self.init_values:
|
|
|
|
|
nn.init.constant_(self.norm1.weight, self.init_values)
|
|
|
|
|
nn.init.constant_(self.norm2.weight, self.init_values)
|
|
|
|
|
|
|
|
|
|
def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None:
|
|
|
|
|
"""Method updates the image resolution to be processed and window size and so the pair-wise relative positions.
|
|
|
|
|
|
|
|
|
@ -435,7 +449,7 @@ class SwinTransformerBlock(nn.Module):
|
|
|
|
|
Returns:
|
|
|
|
|
output (torch.Tensor): Output tensor of the shape [B, C, H, W]
|
|
|
|
|
"""
|
|
|
|
|
# NOTE post-norm branches (op -> norm -> drop)
|
|
|
|
|
# post-norm branches (op -> norm -> drop)
|
|
|
|
|
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
|
|
|
|
|
x = x + self.drop_path2(self.norm2(self.mlp(x)))
|
|
|
|
|
x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
|
|
|
|
@ -522,6 +536,7 @@ class SwinTransformerStage(nn.Module):
|
|
|
|
|
feat_size: Tuple[int, int],
|
|
|
|
|
window_size: Tuple[int, int],
|
|
|
|
|
mlp_ratio: float = 4.0,
|
|
|
|
|
init_values: float = 0.0,
|
|
|
|
|
drop: float = 0.0,
|
|
|
|
|
drop_attn: float = 0.0,
|
|
|
|
|
drop_path: Union[List[float], float] = 0.0,
|
|
|
|
@ -552,6 +567,7 @@ class SwinTransformerStage(nn.Module):
|
|
|
|
|
window_size=window_size,
|
|
|
|
|
shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]),
|
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
|
init_values=init_values,
|
|
|
|
|
drop=drop,
|
|
|
|
|
drop_attn=drop_attn,
|
|
|
|
|
drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
|
|
|
|
@ -634,6 +650,7 @@ class SwinTransformerV2Cr(nn.Module):
|
|
|
|
|
depths: Tuple[int, ...] = (2, 2, 6, 2),
|
|
|
|
|
num_heads: Tuple[int, ...] = (3, 6, 12, 24),
|
|
|
|
|
mlp_ratio: float = 4.0,
|
|
|
|
|
init_values: float = 0.0,
|
|
|
|
|
drop_rate: float = 0.0,
|
|
|
|
|
attn_drop_rate: float = 0.0,
|
|
|
|
|
drop_path_rate: float = 0.0,
|
|
|
|
@ -674,6 +691,7 @@ class SwinTransformerV2Cr(nn.Module):
|
|
|
|
|
num_heads=num_heads,
|
|
|
|
|
window_size=window_size,
|
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
|
init_values=init_values,
|
|
|
|
|
drop=drop_rate,
|
|
|
|
|
drop_attn=attn_drop_rate,
|
|
|
|
|
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
|
|
|
|
@ -786,6 +804,8 @@ def init_weights(module: nn.Module, name: str = ''):
|
|
|
|
|
nn.init.xavier_uniform_(module.weight)
|
|
|
|
|
if module.bias is not None:
|
|
|
|
|
nn.init.zeros_(module.bias)
|
|
|
|
|
elif hasattr(module, 'init_weights'):
|
|
|
|
|
module.init_weights()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
|
|
|
|
@ -863,6 +883,20 @@ def swin_v2_cr_small_224(pretrained=False, **kwargs):
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_small_ns_224(pretrained=False, **kwargs):
|
|
|
|
|
"""Swin-S V2 CR @ 224x224, trained ImageNet-1k"""
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
embed_dim=96,
|
|
|
|
|
depths=(2, 2, 18, 2),
|
|
|
|
|
num_heads=(3, 6, 12, 24),
|
|
|
|
|
init_values=1e-5,
|
|
|
|
|
extra_norm_stage=True,
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_small_ns_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_base_384(pretrained=False, **kwargs):
|
|
|
|
|
"""Swin-B V2 CR @ 384x384, trained ImageNet-1k"""
|
|
|
|
@ -887,6 +921,20 @@ def swin_v2_cr_base_224(pretrained=False, **kwargs):
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_base_ns_224(pretrained=False, **kwargs):
|
|
|
|
|
"""Swin-B V2 CR @ 224x224, trained ImageNet-1k"""
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
embed_dim=128,
|
|
|
|
|
depths=(2, 2, 18, 2),
|
|
|
|
|
num_heads=(4, 8, 16, 32),
|
|
|
|
|
init_values=1e-6,
|
|
|
|
|
extra_norm_stage=True,
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
return _create_swin_transformer_v2_cr('swin_v2_cr_base_ns_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def swin_v2_cr_large_384(pretrained=False, **kwargs):
|
|
|
|
|
"""Swin-L V2 CR @ 384x384, trained ImageNet-1k"""
|
|
|
|
|