diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index bb8fe3cc..472ae205 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -76,10 +76,15 @@ default_cfgs = { 'swin_v2_cr_small_224': _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth", input_size=(3, 224, 224), crop_pct=0.9), + 'swin_v2_cr_small_ns_224': _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth", + input_size=(3, 224, 224), crop_pct=0.9), 'swin_v2_cr_base_384': _cfg( url="", input_size=(3, 384, 384), crop_pct=1.0), 'swin_v2_cr_base_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), + 'swin_v2_cr_base_ns_224': _cfg( + url="", input_size=(3, 224, 224), crop_pct=0.9), 'swin_v2_cr_large_384': _cfg( url="", input_size=(3, 384, 384), crop_pct=1.0), 'swin_v2_cr_large_224': _cfg( @@ -179,7 +184,7 @@ class WindowMultiHeadAttention(nn.Module): hidden_features=meta_hidden_dim, out_features=num_heads, act_layer=nn.ReLU, - drop=0.1 # FIXME should there be stochasticity, appears to 'overfit' without? + drop=(0.125, 0.) # FIXME should there be stochasticity, appears to 'overfit' without? ) self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads))) self._make_pair_wise_relative_positions() @@ -304,6 +309,7 @@ class SwinTransformerBlock(nn.Module): window_size: Tuple[int, int], shift_size: Tuple[int, int] = (0, 0), mlp_ratio: float = 4.0, + init_values: float = 0, drop: float = 0.0, drop_attn: float = 0.0, drop_path: float = 0.0, @@ -317,6 +323,7 @@ class SwinTransformerBlock(nn.Module): self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size) self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) self.window_area = self.window_size[0] * self.window_size[1] + self.init_values: float = init_values # attn branch self.attn = WindowMultiHeadAttention( @@ -345,6 +352,7 @@ class SwinTransformerBlock(nn.Module): self.norm3 = norm_layer(dim) if extra_norm else nn.Identity() self._make_attention_mask() + self.init_weights() def _calc_window_shift(self, target_window_size): window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)] @@ -377,6 +385,12 @@ class SwinTransformerBlock(nn.Module): attn_mask = None self.register_buffer("attn_mask", attn_mask, persistent=False) + def init_weights(self): + # extra, module specific weight init + if self.init_values: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None: """Method updates the image resolution to be processed and window size and so the pair-wise relative positions. @@ -435,7 +449,7 @@ class SwinTransformerBlock(nn.Module): Returns: output (torch.Tensor): Output tensor of the shape [B, C, H, W] """ - # NOTE post-norm branches (op -> norm -> drop) + # post-norm branches (op -> norm -> drop) x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x))) x = x + self.drop_path2(self.norm2(self.mlp(x))) x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant) @@ -522,6 +536,7 @@ class SwinTransformerStage(nn.Module): feat_size: Tuple[int, int], window_size: Tuple[int, int], mlp_ratio: float = 4.0, + init_values: float = 0.0, drop: float = 0.0, drop_attn: float = 0.0, drop_path: Union[List[float], float] = 0.0, @@ -552,6 +567,7 @@ class SwinTransformerStage(nn.Module): window_size=window_size, shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]), mlp_ratio=mlp_ratio, + init_values=init_values, drop=drop, drop_attn=drop_attn, drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path, @@ -634,6 +650,7 @@ class SwinTransformerV2Cr(nn.Module): depths: Tuple[int, ...] = (2, 2, 6, 2), num_heads: Tuple[int, ...] = (3, 6, 12, 24), mlp_ratio: float = 4.0, + init_values: float = 0.0, drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, @@ -674,6 +691,7 @@ class SwinTransformerV2Cr(nn.Module): num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, + init_values=init_values, drop=drop_rate, drop_attn=attn_drop_rate, drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])], @@ -786,6 +804,8 @@ def init_weights(module: nn.Module, name: str = ''): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs): @@ -863,6 +883,20 @@ def swin_v2_cr_small_224(pretrained=False, **kwargs): return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **model_kwargs) +@register_model +def swin_v2_cr_small_ns_224(pretrained=False, **kwargs): + """Swin-S V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=96, + depths=(2, 2, 18, 2), + num_heads=(3, 6, 12, 24), + init_values=1e-5, + extra_norm_stage=True, + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_small_ns_224', pretrained=pretrained, **model_kwargs) + + @register_model def swin_v2_cr_base_384(pretrained=False, **kwargs): """Swin-B V2 CR @ 384x384, trained ImageNet-1k""" @@ -887,6 +921,20 @@ def swin_v2_cr_base_224(pretrained=False, **kwargs): return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **model_kwargs) +@register_model +def swin_v2_cr_base_ns_224(pretrained=False, **kwargs): + """Swin-B V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=128, + depths=(2, 2, 18, 2), + num_heads=(4, 8, 16, 32), + init_values=1e-6, + extra_norm_stage=True, + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_base_ns_224', pretrained=pretrained, **model_kwargs) + + @register_model def swin_v2_cr_large_384(pretrained=False, **kwargs): """Swin-L V2 CR @ 384x384, trained ImageNet-1k"""