Add Swin-V2 Small-NS weights (83.5 @ 224). Add layer scale like 'init_values' via post-norm LN weight scaling

pull/1245/head
Ross Wightman 3 years ago
parent 001688dabf
commit b7cb8d0337

@ -76,10 +76,15 @@ default_cfgs = {
'swin_v2_cr_small_224': _cfg( 'swin_v2_cr_small_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth", url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
input_size=(3, 224, 224), crop_pct=0.9), input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_small_ns_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_base_384': _cfg( 'swin_v2_cr_base_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_base_224': _cfg( 'swin_v2_cr_base_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9), url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_base_ns_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_large_384': _cfg( 'swin_v2_cr_large_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_large_224': _cfg( 'swin_v2_cr_large_224': _cfg(
@ -179,7 +184,7 @@ class WindowMultiHeadAttention(nn.Module):
hidden_features=meta_hidden_dim, hidden_features=meta_hidden_dim,
out_features=num_heads, out_features=num_heads,
act_layer=nn.ReLU, act_layer=nn.ReLU,
drop=0.1 # FIXME should there be stochasticity, appears to 'overfit' without? drop=(0.125, 0.) # FIXME should there be stochasticity, appears to 'overfit' without?
) )
self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads))) self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
self._make_pair_wise_relative_positions() self._make_pair_wise_relative_positions()
@ -304,6 +309,7 @@ class SwinTransformerBlock(nn.Module):
window_size: Tuple[int, int], window_size: Tuple[int, int],
shift_size: Tuple[int, int] = (0, 0), shift_size: Tuple[int, int] = (0, 0),
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
init_values: float = 0,
drop: float = 0.0, drop: float = 0.0,
drop_attn: float = 0.0, drop_attn: float = 0.0,
drop_path: float = 0.0, drop_path: float = 0.0,
@ -317,6 +323,7 @@ class SwinTransformerBlock(nn.Module):
self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size) self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
self.window_area = self.window_size[0] * self.window_size[1] self.window_area = self.window_size[0] * self.window_size[1]
self.init_values: float = init_values
# attn branch # attn branch
self.attn = WindowMultiHeadAttention( self.attn = WindowMultiHeadAttention(
@ -345,6 +352,7 @@ class SwinTransformerBlock(nn.Module):
self.norm3 = norm_layer(dim) if extra_norm else nn.Identity() self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()
self._make_attention_mask() self._make_attention_mask()
self.init_weights()
def _calc_window_shift(self, target_window_size): def _calc_window_shift(self, target_window_size):
window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)] window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)]
@ -377,6 +385,12 @@ class SwinTransformerBlock(nn.Module):
attn_mask = None attn_mask = None
self.register_buffer("attn_mask", attn_mask, persistent=False) self.register_buffer("attn_mask", attn_mask, persistent=False)
def init_weights(self):
# extra, module specific weight init
if self.init_values:
nn.init.constant_(self.norm1.weight, self.init_values)
nn.init.constant_(self.norm2.weight, self.init_values)
def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None: def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None:
"""Method updates the image resolution to be processed and window size and so the pair-wise relative positions. """Method updates the image resolution to be processed and window size and so the pair-wise relative positions.
@ -435,7 +449,7 @@ class SwinTransformerBlock(nn.Module):
Returns: Returns:
output (torch.Tensor): Output tensor of the shape [B, C, H, W] output (torch.Tensor): Output tensor of the shape [B, C, H, W]
""" """
# NOTE post-norm branches (op -> norm -> drop) # post-norm branches (op -> norm -> drop)
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x))) x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
x = x + self.drop_path2(self.norm2(self.mlp(x))) x = x + self.drop_path2(self.norm2(self.mlp(x)))
x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant) x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
@ -522,6 +536,7 @@ class SwinTransformerStage(nn.Module):
feat_size: Tuple[int, int], feat_size: Tuple[int, int],
window_size: Tuple[int, int], window_size: Tuple[int, int],
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
init_values: float = 0.0,
drop: float = 0.0, drop: float = 0.0,
drop_attn: float = 0.0, drop_attn: float = 0.0,
drop_path: Union[List[float], float] = 0.0, drop_path: Union[List[float], float] = 0.0,
@ -552,6 +567,7 @@ class SwinTransformerStage(nn.Module):
window_size=window_size, window_size=window_size,
shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]), shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]),
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
init_values=init_values,
drop=drop, drop=drop,
drop_attn=drop_attn, drop_attn=drop_attn,
drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path, drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
@ -634,6 +650,7 @@ class SwinTransformerV2Cr(nn.Module):
depths: Tuple[int, ...] = (2, 2, 6, 2), depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24), num_heads: Tuple[int, ...] = (3, 6, 12, 24),
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
init_values: float = 0.0,
drop_rate: float = 0.0, drop_rate: float = 0.0,
attn_drop_rate: float = 0.0, attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0, drop_path_rate: float = 0.0,
@ -674,6 +691,7 @@ class SwinTransformerV2Cr(nn.Module):
num_heads=num_heads, num_heads=num_heads,
window_size=window_size, window_size=window_size,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
init_values=init_values,
drop=drop_rate, drop=drop_rate,
drop_attn=attn_drop_rate, drop_attn=attn_drop_rate,
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])], drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
@ -786,6 +804,8 @@ def init_weights(module: nn.Module, name: str = ''):
nn.init.xavier_uniform_(module.weight) nn.init.xavier_uniform_(module.weight)
if module.bias is not None: if module.bias is not None:
nn.init.zeros_(module.bias) nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs): def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
@ -863,6 +883,20 @@ def swin_v2_cr_small_224(pretrained=False, **kwargs):
return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_small_ns_224(pretrained=False, **kwargs):
"""Swin-S V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=96,
depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24),
init_values=1e-5,
extra_norm_stage=True,
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_small_ns_224', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def swin_v2_cr_base_384(pretrained=False, **kwargs): def swin_v2_cr_base_384(pretrained=False, **kwargs):
"""Swin-B V2 CR @ 384x384, trained ImageNet-1k""" """Swin-B V2 CR @ 384x384, trained ImageNet-1k"""
@ -887,6 +921,20 @@ def swin_v2_cr_base_224(pretrained=False, **kwargs):
return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **model_kwargs)
@register_model
def swin_v2_cr_base_ns_224(pretrained=False, **kwargs):
"""Swin-B V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=128,
depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32),
init_values=1e-6,
extra_norm_stage=True,
**kwargs
)
return _create_swin_transformer_v2_cr('swin_v2_cr_base_ns_224', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def swin_v2_cr_large_384(pretrained=False, **kwargs): def swin_v2_cr_large_384(pretrained=False, **kwargs):
"""Swin-L V2 CR @ 384x384, trained ImageNet-1k""" """Swin-L V2 CR @ 384x384, trained ImageNet-1k"""

Loading…
Cancel
Save