diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py
index d7c36519..9659b5ec 100644
--- a/timm/models/swin_transformer_v2_cr.py
+++ b/timm/models/swin_transformer_v2_cr.py
@@ -12,7 +12,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
 # Licensed under The MIT License [see LICENSE for details]
 # Written by Christoph Reich
 # --------------------------------------------------------
-from typing import Tuple, Optional, List, Union, Any
+from typing import Tuple, Optional, List, Union, Any, Type
 
 import torch
 import torch.nn as nn
@@ -717,6 +717,7 @@ class SwinTransformerStage(nn.Module):
                  dropout: float = 0.0,
                  dropout_attention: float = 0.0,
                  dropout_path: Union[List[float], float] = 0.0,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                  use_checkpoint: bool = False,
                  sequential_self_attention: bool = False,
                  use_deformable_block: bool = False) -> None:
@@ -791,75 +792,78 @@ class SwinTransformerV2CR(nn.Module):
           https://arxiv.org/pdf/2111.09883
 
     Args:
-        in_channels (int): Number of input channels
-        depth (int): Depth of the stage (number of layers)
-        downscale (bool): If true input is downsampled (see Fig. 3 or V1 paper)
-        input_resolution (Tuple[int, int]): Input resolution
-        number_of_heads (int): Number of attention heads to be utilized
-        num_classes (int): Number of output classes
-        window_size (int): Window size to be utilized
-        shift_size (int): Shifting size to be used
-        ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
-        dropout (float): Dropout in input mapping
-        dropout_attention (float): Dropout rate of attention map
-        dropout_path (float): Dropout in main path
-        use_checkpoint (bool): If true checkpointing is utilized
-        sequential_self_attention (bool): If true sequential self-attention is performed
-        use_deformable_block (bool): If true deformable block is used
+        img_size (Tuple[int, int]): Input resolution.
+        in_chans (int): Number of input channels.
+        depths (Tuple[int, ...]): Depth of each stage (number of layers per stage).
+        num_heads (Tuple[int, ...]): Number of attention heads to be utilized per stage.
+        embed_dim (int): Patch embedding dimension. Default: 96
+        num_classes (int): Number of output classes. Default: 1000
+        window_size (int): Window size to be utilized. Default: 7
+        patch_size (int): Patch size. Default: 4
+        mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels. Default: 4
+        drop_rate (float): Dropout rate. Default: 0.0
+        attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
+        drop_path_rate (float): Stochastic depth rate. Default: 0.0
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
+        use_checkpoint (bool): If true checkpointing is utilized. Default: False
+        sequential_self_attention (bool): If true sequential self-attention is performed. Default: False
+        use_deformable_block (bool): If true deformable block is used. Default: False
     """
 
     def __init__(self,
-                 in_channels: int,
-                 embedding_channels: int,
+                 img_size: Tuple[int, int],
+                 in_chans: int,
                  depths: Tuple[int, ...],
-                 input_resolution: Tuple[int, int],
-                 number_of_heads: Tuple[int, ...],
+                 num_heads: Tuple[int, ...],
+                 embed_dim: int = 96,
                  num_classes: int = 1000,
                  window_size: int = 7,
                  patch_size: int = 4,
-                 ff_feature_ratio: int = 4,
-                 dropout: float = 0.0,
-                 dropout_attention: float = 0.0,
-                 dropout_path: float = 0.2,
+                 mlp_ratio: int = 4,
+                 drop_rate: float = 0.0,
+                 attn_drop_rate: float = 0.0,
+                 drop_path_rate: float = 0.0,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                  use_checkpoint: bool = False,
                  sequential_self_attention: bool = False,
-                 use_deformable_block: bool = False) -> None:
+                 use_deformable_block: bool = False,
+                 **kwargs: Any) -> None:
         # Call super constructor
         super(SwinTransformerV2CR, self).__init__()
         # Save parameters
         self.patch_size: int = patch_size
-        self.input_resolution: Tuple[int, int] = input_resolution
+        self.input_resolution: Tuple[int, int] = img_size
         self.window_size: int = window_size
         # Init patch embedding
-        self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_channels, out_channels=embedding_channels,
+        self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim,
                                                          patch_size=patch_size)
         # Compute patch resolution
-        patch_resolution: Tuple[int, int] = (input_resolution[0] // patch_size, input_resolution[1] // patch_size)
+        patch_resolution: Tuple[int, int] = (img_size[0] // patch_size, img_size[1] // patch_size)
         # Path dropout dependent on depth
-        dropout_path = torch.linspace(0., dropout_path, sum(depths)).tolist()
+        drop_path_rate = torch.linspace(0., drop_path_rate, sum(depths)).tolist()
         # Init stages
         self.stages: nn.ModuleList = nn.ModuleList()
-        for index, (depth, number_of_head) in enumerate(zip(depths, number_of_heads)):
+        for index, (depth, number_of_head) in enumerate(zip(depths, num_heads)):
             self.stages.append(
                 SwinTransformerStage(
-                    in_channels=embedding_channels * (2 ** max(index - 1, 0)),
+                    in_channels=embed_dim * (2 ** max(index - 1, 0)),
                     depth=depth,
                     downscale=index != 0,
                     input_resolution=(patch_resolution[0] // (2 ** max(index - 1, 0)),
                                       patch_resolution[1] // (2 ** max(index - 1, 0))),
                     number_of_heads=number_of_head,
                     window_size=window_size,
-                    ff_feature_ratio=ff_feature_ratio,
-                    dropout=dropout,
-                    dropout_attention=dropout_attention,
-                    dropout_path=dropout_path[sum(depths[:index]):sum(depths[:index + 1])],
+                    ff_feature_ratio=mlp_ratio,
+                    dropout=drop_rate,
+                    dropout_attention=attn_drop_rate,
+                    dropout_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
                     use_checkpoint=use_checkpoint,
                     sequential_self_attention=sequential_self_attention,
                     use_deformable_block=use_deformable_block and (index > 0)
                 ))
         # Init final adaptive average pooling, and classification head
         self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1)
-        self.head: nn.Module = nn.Linear(in_features=embedding_channels * (2 ** (len(depths) - 1)),
+        self.head: nn.Module = nn.Linear(in_features=embed_dim * (2 ** (len(depths) - 1)),
                                          out_features=num_classes)
 
     def update_resolution(self,
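
For reference, a minimal usage sketch against the renamed constructor arguments above. Only the argument names, types, and defaults come from the new signature in this diff; the depths/num_heads values (a tiny-style configuration), the 224x224 input size, and the drop_path_rate value are illustrative assumptions, and registry entry points or default cfgs are not shown here, so the class is instantiated directly.

    import torch
    import torch.nn as nn
    from timm.models.swin_transformer_v2_cr import SwinTransformerV2CR

    # Instantiate with the renamed arguments (old names noted in comments).
    model = SwinTransformerV2CR(
        img_size=(224, 224),        # was input_resolution
        in_chans=3,                 # was in_channels
        depths=(2, 2, 6, 2),        # assumed tiny-style config, not from this diff
        num_heads=(3, 6, 12, 24),   # was number_of_heads
        embed_dim=96,               # was embedding_channels
        mlp_ratio=4,                # was ff_feature_ratio
        drop_rate=0.0,              # was dropout
        attn_drop_rate=0.0,         # was dropout_attention
        drop_path_rate=0.1,         # was dropout_path
        norm_layer=nn.LayerNorm,    # newly exposed argument
    )

    x = torch.randn(1, 3, 224, 224)
    logits = model(x)               # expected shape: (1, num_classes) == (1, 1000)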