@@ -408,6 +408,7 @@ class SwinTransformerBlock(nn.Module):
         dropout_attention (float): Dropout rate of attention map
         dropout_path (float): Dropout in main path
         sequential_self_attention (bool): If true sequential self-attention is performed
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized
     """

     def __init__(self,
@@ -420,7 +421,8 @@ class SwinTransformerBlock(nn.Module):
                  dropout: float = 0.0,
                  dropout_attention: float = 0.0,
                  dropout_path: float = 0.0,
-                 sequential_self_attention: bool = False) -> None:
+                 sequential_self_attention: bool = False,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:
         # Call super constructor
         super(SwinTransformerBlock, self).__init__()
         # Save parameters
@@ -436,8 +438,8 @@ class SwinTransformerBlock(nn.Module):
         self.shift_size: int = shift_size
         self.make_windows: bool = True
         # Init normalization layers
-        self.normalization_1: nn.Module = nn.LayerNorm(normalized_shape=in_channels)
-        self.normalization_2: nn.Module = nn.LayerNorm(normalized_shape=in_channels)
+        self.normalization_1: nn.Module = norm_layer(normalized_shape=in_channels)
+        self.normalization_2: nn.Module = norm_layer(normalized_shape=in_channels)
         # Init window attention module
         self.window_attention: WindowMultiHeadAttention = WindowMultiHeadAttention(
             in_features=in_channels,
@@ -569,6 +571,7 @@ class DeformableSwinTransformerBlock(SwinTransformerBlock):
         dropout_path (float): Dropout in main path
         sequential_self_attention (bool): If true sequential self-attention is performed
         offset_downscale_factor (int): Downscale factor of offset network
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized
     """

     def __init__(self,
@@ -582,7 +585,8 @@ class DeformableSwinTransformerBlock(SwinTransformerBlock):
                  dropout_attention: float = 0.0,
                  dropout_path: float = 0.0,
                  sequential_self_attention: bool = False,
-                 offset_downscale_factor: int = 2) -> None:
+                 offset_downscale_factor: int = 2,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:
         # Call super constructor
         super(DeformableSwinTransformerBlock, self).__init__(
             in_channels=in_channels,
@@ -594,7 +598,8 @@ class DeformableSwinTransformerBlock(SwinTransformerBlock):
             dropout=dropout,
             dropout_attention=dropout_attention,
             dropout_path=dropout_path,
-            sequential_self_attention=sequential_self_attention
+            sequential_self_attention=sequential_self_attention,
+            norm_layer=norm_layer
         )
         # Save parameter
         self.offset_downscale_factor: int = offset_downscale_factor
@@ -684,14 +689,16 @@ class PatchMerging(nn.Module):

     Args:
         in_channels (int): Number of input channels
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized.
     """

     def __init__(self,
-                 in_channels: int) -> None:
+                 in_channels: int,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:
         # Call super constructor
         super(PatchMerging, self).__init__()
         # Init normalization
-        self.normalization: nn.Module = nn.LayerNorm(normalized_shape=4 * in_channels)
+        self.normalization: nn.Module = norm_layer(normalized_shape=4 * in_channels)
         # Init linear mapping
         self.linear_mapping: nn.Module = nn.Linear(in_features=4 * in_channels, out_features=2 * in_channels,
                                                    bias=False)
@@ -728,12 +735,14 @@ class PatchEmbedding(nn.Module):
         out_channels (int): Number of output channels
         patch_size (int): Patch size to be utilized
         image_size (int): Image size to be used
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized
     """

     def __init__(self,
                  in_channels: int = 3,
                  out_channels: int = 96,
-                 patch_size: int = 4) -> None:
+                 patch_size: int = 4,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:
         # Call super constructor
         super(PatchEmbedding, self).__init__()
         # Save parameters
@@ -743,7 +752,7 @@ class PatchEmbedding(nn.Module):
                                                      kernel_size=(patch_size, patch_size),
                                                      stride=(patch_size, patch_size))
         # Init layer normalization
-        self.normalization: nn.Module = nn.LayerNorm(normalized_shape=out_channels)
+        self.normalization: nn.Module = norm_layer(normalized_shape=out_channels)

     def forward(self,
                 input: torch.Tensor) -> torch.Tensor:
@@ -777,6 +786,7 @@ class SwinTransformerStage(nn.Module):
         dropout (float): Dropout in input mapping
         dropout_attention (float): Dropout rate of attention map
         dropout_path (float): Dropout in main path
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
         use_checkpoint (bool): If true checkpointing is utilized
         sequential_self_attention (bool): If true sequential self-attention is performed
         use_deformable_block (bool): If true deformable block is used
@@ -803,7 +813,8 @@ class SwinTransformerStage(nn.Module):
         self.use_checkpoint: bool = use_checkpoint
         self.downscale: bool = downscale
         # Init downsampling
-        self.downsample: nn.Module = PatchMerging(in_channels=in_channels) if downscale else nn.Identity()
+        self.downsample: nn.Module = PatchMerging(in_channels=in_channels, norm_layer=norm_layer) \
+            if downscale else nn.Identity()
         # Update resolution and channels
         self.input_resolution: Tuple[int, int] = (input_resolution[0] // 2, input_resolution[1] // 2) \
             if downscale else input_resolution
@@ -821,7 +832,8 @@ class SwinTransformerStage(nn.Module):
                   dropout=dropout,
                   dropout_attention=dropout_attention,
                   dropout_path=dropout_path[index] if isinstance(dropout_path, list) else dropout_path,
-                  sequential_self_attention=sequential_self_attention)
+                  sequential_self_attention=sequential_self_attention,
+                  norm_layer=norm_layer)
             for index in range(depth)])

     def update_resolution(self,
@@ -914,7 +926,7 @@ class SwinTransformerV2CR(nn.Module):
         self.num_features: int = int(embed_dim * (2 ** len(depths) - 1))
         # Init patch embedding
         self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim,
-                                                         patch_size=patch_size)
+                                                         patch_size=patch_size, norm_layer=norm_layer)
         # Compute patch resolution
         patch_resolution: Tuple[int, int] = (img_size[0] // patch_size, img_size[1] // patch_size)
         # Path dropout dependent on depth
@@ -937,7 +949,8 @@ class SwinTransformerV2CR(nn.Module):
                 dropout_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
                 use_checkpoint=use_checkpoint,
                 sequential_self_attention=sequential_self_attention,
-                use_deformable_block=use_deformable_block and (index > 0)
+                use_deformable_block=use_deformable_block and (index > 0),
+                norm_layer=norm_layer
             ))
         # Init final adaptive average pooling, and classification head
         self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1)
@@ -1165,17 +1178,3 @@ def swin_v2_cr_giant_patch4_window7_224(pretrained=False, **kwargs):
     model_kwargs = dict(img_size=(224, 224), patch_size=4, window_size=7, embed_dim=512, depths=(2, 2, 42, 2),
                         num_heads=(16, 32, 64, 128), **kwargs)
     return _create_swin_transformer_v2_cr('swin_v2_cr_giant_patch4_window7_224', pretrained=pretrained, **model_kwargs)
-
-
-if __name__ == '__main__':
-    model = swin_v2_cr_tiny_patch4_window12_384(pretrained=False)
-    model = swin_v2_cr_tiny_patch4_window7_224(pretrained=False)
-
-    model = swin_v2_cr_small_patch4_window12_384(pretrained=False)
-    model = swin_v2_cr_small_patch4_window7_224(pretrained=False)
-
-    model = swin_v2_cr_base_patch4_window12_384(pretrained=False)
-    model = swin_v2_cr_base_patch4_window7_224(pretrained=False)
-
-    model = swin_v2_cr_large_patch4_window12_384(pretrained=False)
-    model = swin_v2_cr_large_patch4_window7_224(pretrained=False)
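
A minimal usage sketch for the norm_layer argument introduced by this patch. It assumes the patched module is importable as swin_transformer_v2_cr, that the model factories forward extra keyword arguments to SwinTransformerV2CR, and that SwinTransformerV2CR itself exposes norm_layer as the other constructors above do; the eps value is purely illustrative.

from functools import partial

import torch
import torch.nn as nn

# Assumed import path for the patched module shown above.
from swin_transformer_v2_cr import swin_v2_cr_tiny_patch4_window7_224

# Every call site in the patch invokes norm_layer(normalized_shape=...), so any
# callable accepting that keyword fits; here a LayerNorm with a non-default eps.
custom_norm = partial(nn.LayerNorm, eps=1e-5)

model = swin_v2_cr_tiny_patch4_window7_224(pretrained=False, norm_layer=custom_norm)
logits = model(torch.randn(1, 3, 224, 224))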