From 74a04e0016f6f3e3b6484236ab8815561e3c5a86 Mon Sep 17 00:00:00 2001
From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com>
Date: Sun, 20 Feb 2022 00:46:00 +0100
Subject: [PATCH] Add parameter to change normalization type

---
 timm/models/swin_transformer_v2_cr.py | 53 +++++++++++++--------------
 1 file changed, 26 insertions(+), 27 deletions(-)

diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py
index 9aad19c0..6df1ff9b 100644
--- a/timm/models/swin_transformer_v2_cr.py
+++ b/timm/models/swin_transformer_v2_cr.py
@@ -408,6 +408,7 @@ class SwinTransformerBlock(nn.Module):
         dropout_attention (float): Dropout rate of attention map
         dropout_path (float): Dropout in main path
         sequential_self_attention (bool): If true sequential self-attention is performed
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized
     """
 
     def __init__(self,
@@ -420,7 +421,8 @@ class SwinTransformerBlock(nn.Module):
                  dropout: float = 0.0,
                  dropout_attention: float = 0.0,
                  dropout_path: float = 0.0,
-                 sequential_self_attention: bool = False) -> None:
+                 sequential_self_attention: bool = False,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:
         # Call super constructor
         super(SwinTransformerBlock, self).__init__()
         # Save parameters
@@ -436,8 +438,8 @@ class SwinTransformerBlock(nn.Module):
         self.shift_size: int = shift_size
         self.make_windows: bool = True
         # Init normalization layers
-        self.normalization_1: nn.Module = nn.LayerNorm(normalized_shape=in_channels)
-        self.normalization_2: nn.Module = nn.LayerNorm(normalized_shape=in_channels)
+        self.normalization_1: nn.Module = norm_layer(normalized_shape=in_channels)
+        self.normalization_2: nn.Module = norm_layer(normalized_shape=in_channels)
         # Init window attention module
         self.window_attention: WindowMultiHeadAttention = WindowMultiHeadAttention(
             in_features=in_channels,
@@ -569,6 +571,7 @@ class DeformableSwinTransformerBlock(SwinTransformerBlock):
         dropout_path (float): Dropout in main path
         sequential_self_attention (bool): If true sequential self-attention is performed
         offset_downscale_factor (int): Downscale factor of offset network
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized
     """
 
     def __init__(self,
@@ -582,7 +585,8 @@ class DeformableSwinTransformerBlock(SwinTransformerBlock):
                  dropout_attention: float = 0.0,
                  dropout_path: float = 0.0,
                  sequential_self_attention: bool = False,
-                 offset_downscale_factor: int = 2) -> None:
+                 offset_downscale_factor: int = 2,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:
         # Call super constructor
         super(DeformableSwinTransformerBlock, self).__init__(
             in_channels=in_channels,
@@ -594,7 +598,8 @@ class DeformableSwinTransformerBlock(SwinTransformerBlock):
             dropout=dropout,
             dropout_attention=dropout_attention,
             dropout_path=dropout_path,
-            sequential_self_attention=sequential_self_attention
+            sequential_self_attention=sequential_self_attention,
+            norm_layer=norm_layer
         )
         # Save parameter
         self.offset_downscale_factor: int = offset_downscale_factor
@@ -684,14 +689,16 @@ class PatchMerging(nn.Module):
 
     Args:
         in_channels (int): Number of input channels
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized.
     """
 
     def __init__(self,
-                 in_channels: int) -> None:
+                 in_channels: int,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:
         # Call super constructor
         super(PatchMerging, self).__init__()
         # Init normalization
-        self.normalization: nn.Module = nn.LayerNorm(normalized_shape=4 * in_channels)
+        self.normalization: nn.Module = norm_layer(normalized_shape=4 * in_channels)
         # Init linear mapping
         self.linear_mapping: nn.Module = nn.Linear(in_features=4 * in_channels, out_features=2 * in_channels,
                                                    bias=False)
@@ -728,12 +735,14 @@ class PatchEmbedding(nn.Module):
         out_channels (int): Number of output channels
         patch_size (int): Patch size to be utilized
         image_size (int): Image size to be used
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized
     """
 
     def __init__(self,
                  in_channels: int = 3,
                  out_channels: int = 96,
-                 patch_size: int = 4) -> None:
+                 patch_size: int = 4,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:
         # Call super constructor
         super(PatchEmbedding, self).__init__()
         # Save parameters
@@ -743,7 +752,7 @@ class PatchEmbedding(nn.Module):
                                                      kernel_size=(patch_size, patch_size),
                                                      stride=(patch_size, patch_size))
         # Init layer normalization
-        self.normalization: nn.Module = nn.LayerNorm(normalized_shape=out_channels)
+        self.normalization: nn.Module = norm_layer(normalized_shape=out_channels)
 
     def forward(self,
                 input: torch.Tensor) -> torch.Tensor:
@@ -777,6 +786,7 @@ class SwinTransformerStage(nn.Module):
         dropout (float): Dropout in input mapping
         dropout_attention (float): Dropout rate of attention map
         dropout_path (float): Dropout in main path
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
         use_checkpoint (bool): If true checkpointing is utilized
         sequential_self_attention (bool): If true sequential self-attention is performed
         use_deformable_block (bool): If true deformable block is used
@@ -803,7 +813,8 @@ class SwinTransformerStage(nn.Module):
         self.use_checkpoint: bool = use_checkpoint
         self.downscale: bool = downscale
         # Init downsampling
-        self.downsample: nn.Module = PatchMerging(in_channels=in_channels) if downscale else nn.Identity()
+        self.downsample: nn.Module = PatchMerging(in_channels=in_channels, norm_layer=norm_layer) \
+            if downscale else nn.Identity()
         # Update resolution and channels
         self.input_resolution: Tuple[int, int] = (input_resolution[0] // 2, input_resolution[1] // 2) \
             if downscale else input_resolution
@@ -821,7 +832,8 @@ class SwinTransformerStage(nn.Module):
                 dropout=dropout,
                 dropout_attention=dropout_attention,
                 dropout_path=dropout_path[index] if isinstance(dropout_path, list) else dropout_path,
-                sequential_self_attention=sequential_self_attention)
+                sequential_self_attention=sequential_self_attention,
+                norm_layer=norm_layer)
             for index in range(depth)])
 
     def update_resolution(self,
@@ -914,7 +926,7 @@ class SwinTransformerV2CR(nn.Module):
         self.num_features: int = int(embed_dim * (2 ** len(depths) - 1))
         # Init patch embedding
         self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim,
-                                                         patch_size=patch_size)
+                                                         patch_size=patch_size, norm_layer=norm_layer)
         # Compute patch resolution
         patch_resolution: Tuple[int, int] = (img_size[0] // patch_size, img_size[1] // patch_size)
         # Path dropout dependent on depth
@@ -937,7 +949,8 @@ class SwinTransformerV2CR(nn.Module):
                 dropout_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
                 use_checkpoint=use_checkpoint,
                 sequential_self_attention=sequential_self_attention,
-                use_deformable_block=use_deformable_block and (index > 0)
+                use_deformable_block=use_deformable_block and (index > 0),
+                norm_layer=norm_layer
             ))
         # Init final adaptive average pooling, and classification head
         self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1)
@@ -1165,17 +1178,3 @@ def swin_v2_cr_giant_patch4_window7_224(pretrained=False, **kwargs):
     model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=512, depths=(2, 2, 42, 2),
                         num_heads=(16, 32, 64, 128), **kwargs)
     return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window7_224', pretrained=pretrained, **model_kwargs)
-
-
-if __name__ == '__main__':
-    model = swin_v2_cr_tiny_patch4_window12_384(pretrained=False)
-    model = swin_v2_cr_tiny_patch4_window7_224(pretrained=False)
-
-    model = swin_v2_cr_small_patch4_window12_384(pretrained=False)
-    model = swin_v2_cr_small_patch4_window7_224(pretrained=False)
-
-    model = swin_v2_cr_base_patch4_window12_384(pretrained=False)
-    model = swin_v2_cr_base_patch4_window7_224(pretrained=False)
-
-    model = swin_v2_cr_large_patch4_window12_384(pretrained=False)
-    model = swin_v2_cr_large_patch4_window7_224(pretrained=False)
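
A minimal usage sketch of the new parameter, for reference. The hunks above construct every normalization as norm_layer(normalized_shape=...), so any replacement must accept that keyword. The RMSNorm module below is a hypothetical stand-in used only for illustration, and the example assumes the model entrypoints forward the norm_layer keyword through to SwinTransformerV2CR; passing norm_layer=nn.LayerNorm (the default) reproduces the behaviour before this patch.

import torch
import torch.nn as nn

from timm.models.swin_transformer_v2_cr import swin_v2_cr_tiny_patch4_window7_224


class RMSNorm(nn.Module):
    # Hypothetical drop-in normalization. It must accept the normalized_shape
    # keyword because the patched modules call norm_layer(normalized_shape=...).
    def __init__(self, normalized_shape: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Normalize over the channel (last) dimension and rescale, matching the
        # channels-last layout that nn.LayerNorm sees in these blocks.
        scale = input.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return input * scale * self.weight


# Default behaviour (nn.LayerNorm everywhere) is unchanged.
model = swin_v2_cr_tiny_patch4_window7_224(pretrained=False)

# Assumed usage: norm_layer is forwarded from the entrypoint kwargs to
# SwinTransformerV2CR and from there into every patched submodule.
model_rms = swin_v2_cr_tiny_patch4_window7_224(pretrained=False, norm_layer=RMSNorm)

with torch.no_grad():
    logits = model_rms(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000]) with the default 1000-class head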