|
|
@ -12,7 +12,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
|
|
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
|
|
# Written by Christoph Reich
|
|
|
|
# Written by Christoph Reich
|
|
|
|
# --------------------------------------------------------
|
|
|
|
# --------------------------------------------------------
|
|
|
|
from typing import Tuple, Optional, List, Union, Any
|
|
|
|
from typing import Tuple, Optional, List, Union, Any, Type
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn as nn
|
|
|
@ -717,6 +717,7 @@ class SwinTransformerStage(nn.Module):
|
|
|
|
dropout: float = 0.0,
|
|
|
|
dropout: float = 0.0,
|
|
|
|
dropout_attention: float = 0.0,
|
|
|
|
dropout_attention: float = 0.0,
|
|
|
|
dropout_path: Union[List[float], float] = 0.0,
|
|
|
|
dropout_path: Union[List[float], float] = 0.0,
|
|
|
|
|
|
|
|
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
|
|
|
use_checkpoint: bool = False,
|
|
|
|
use_checkpoint: bool = False,
|
|
|
|
sequential_self_attention: bool = False,
|
|
|
|
sequential_self_attention: bool = False,
|
|
|
|
use_deformable_block: bool = False) -> None:
|
|
|
|
use_deformable_block: bool = False) -> None:
|
|
|
@ -791,75 +792,78 @@ class SwinTransformerV2CR(nn.Module):
|
|
|
|
https://arxiv.org/pdf/2111.09883
|
|
|
|
https://arxiv.org/pdf/2111.09883
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
in_channels (int): Number of input channels
|
|
|
|
img_size (Tuple[int, int]): Input resolution.
|
|
|
|
depth (int): Depth of the stage (number of layers)
|
|
|
|
in_chans (int): Number of input channels.
|
|
|
|
downscale (bool): If true input is downsampled (see Fig. 3 or V1 paper)
|
|
|
|
depths (int): Depth of the stage (number of layers).
|
|
|
|
input_resolution (Tuple[int, int]): Input resolution
|
|
|
|
num_heads (int): Number of attention heads to be utilized.
|
|
|
|
number_of_heads (int): Number of attention heads to be utilized
|
|
|
|
embed_dim (int): Patch embedding dimension. Default: 96
|
|
|
|
num_classes (int): Number of output classes
|
|
|
|
num_classes (int): Number of output classes. Default: 1000
|
|
|
|
window_size (int): Window size to be utilized
|
|
|
|
window_size (int): Window size to be utilized. Default: 7
|
|
|
|
shift_size (int): Shifting size to be used
|
|
|
|
patch_size (int | tuple(int)): Patch size. Default: 4
|
|
|
|
ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
|
|
|
|
mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels. Default: 4
|
|
|
|
dropout (float): Dropout in input mapping
|
|
|
|
drop_rate (float): Dropout rate. Default: 0.0
|
|
|
|
dropout_attention (float): Dropout rate of attention map
|
|
|
|
attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
|
|
|
|
dropout_path (float): Dropout in main path
|
|
|
|
drop_path_rate (float): Stochastic depth rate. Default: 0.0
|
|
|
|
use_checkpoint (bool): If true checkpointing is utilized
|
|
|
|
norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
|
|
|
|
sequential_self_attention (bool): If true sequential self-attention is performed
|
|
|
|
use_checkpoint (bool): If true checkpointing is utilized. Default: False
|
|
|
|
use_deformable_block (bool): If true deformable block is used
|
|
|
|
sequential_self_attention (bool): If true sequential self-attention is performed. Default: False
|
|
|
|
|
|
|
|
use_deformable_block (bool): If true deformable block is used. Default: False
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
def __init__(self,
|
|
|
|
in_channels: int,
|
|
|
|
img_size: Tuple[int, int],
|
|
|
|
embedding_channels: int,
|
|
|
|
in_chans: int,
|
|
|
|
depths: Tuple[int, ...],
|
|
|
|
depths: Tuple[int, ...],
|
|
|
|
input_resolution: Tuple[int, int],
|
|
|
|
num_heads: Tuple[int, ...],
|
|
|
|
number_of_heads: Tuple[int, ...],
|
|
|
|
embed_dim: int = 96,
|
|
|
|
num_classes: int = 1000,
|
|
|
|
num_classes: int = 1000,
|
|
|
|
window_size: int = 7,
|
|
|
|
window_size: int = 7,
|
|
|
|
patch_size: int = 4,
|
|
|
|
patch_size: int = 4,
|
|
|
|
ff_feature_ratio: int = 4,
|
|
|
|
mlp_ratio: int = 4,
|
|
|
|
dropout: float = 0.0,
|
|
|
|
drop_rate: float = 0.0,
|
|
|
|
dropout_attention: float = 0.0,
|
|
|
|
attn_drop_rate: float = 0.0,
|
|
|
|
dropout_path: float = 0.2,
|
|
|
|
drop_path_rate: float = 0.0,
|
|
|
|
|
|
|
|
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
|
|
|
use_checkpoint: bool = False,
|
|
|
|
use_checkpoint: bool = False,
|
|
|
|
sequential_self_attention: bool = False,
|
|
|
|
sequential_self_attention: bool = False,
|
|
|
|
use_deformable_block: bool = False) -> None:
|
|
|
|
use_deformable_block: bool = False,
|
|
|
|
|
|
|
|
**kwargs: Any) -> None:
|
|
|
|
# Call super constructor
|
|
|
|
# Call super constructor
|
|
|
|
super(SwinTransformerV2CR, self).__init__()
|
|
|
|
super(SwinTransformerV2CR, self).__init__()
|
|
|
|
# Save parameters
|
|
|
|
# Save parameters
|
|
|
|
self.patch_size: int = patch_size
|
|
|
|
self.patch_size: int = patch_size
|
|
|
|
self.input_resolution: Tuple[int, int] = input_resolution
|
|
|
|
self.input_resolution: Tuple[int, int] = img_size
|
|
|
|
self.window_size: int = window_size
|
|
|
|
self.window_size: int = window_size
|
|
|
|
# Init patch embedding
|
|
|
|
# Init patch embedding
|
|
|
|
self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_channels, out_channels=embedding_channels,
|
|
|
|
self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim,
|
|
|
|
patch_size=patch_size)
|
|
|
|
patch_size=patch_size)
|
|
|
|
# Compute patch resolution
|
|
|
|
# Compute patch resolution
|
|
|
|
patch_resolution: Tuple[int, int] = (input_resolution[0] // patch_size, input_resolution[1] // patch_size)
|
|
|
|
patch_resolution: Tuple[int, int] = (img_size[0] // patch_size, img_size[1] // patch_size)
|
|
|
|
# Path dropout dependent on depth
|
|
|
|
# Path dropout dependent on depth
|
|
|
|
dropout_path = torch.linspace(0., dropout_path, sum(depths)).tolist()
|
|
|
|
drop_path_rate = torch.linspace(0., drop_path_rate, sum(depths)).tolist()
|
|
|
|
# Init stages
|
|
|
|
# Init stages
|
|
|
|
self.stages: nn.ModuleList = nn.ModuleList()
|
|
|
|
self.stages: nn.ModuleList = nn.ModuleList()
|
|
|
|
for index, (depth, number_of_head) in enumerate(zip(depths, number_of_heads)):
|
|
|
|
for index, (depth, number_of_head) in enumerate(zip(depths, num_heads)):
|
|
|
|
self.stages.append(
|
|
|
|
self.stages.append(
|
|
|
|
SwinTransformerStage(
|
|
|
|
SwinTransformerStage(
|
|
|
|
in_channels=embedding_channels * (2 ** max(index - 1, 0)),
|
|
|
|
in_channels=embed_dim * (2 ** max(index - 1, 0)),
|
|
|
|
depth=depth,
|
|
|
|
depth=depth,
|
|
|
|
downscale=index != 0,
|
|
|
|
downscale=index != 0,
|
|
|
|
input_resolution=(patch_resolution[0] // (2 ** max(index - 1, 0)),
|
|
|
|
input_resolution=(patch_resolution[0] // (2 ** max(index - 1, 0)),
|
|
|
|
patch_resolution[1] // (2 ** max(index - 1, 0))),
|
|
|
|
patch_resolution[1] // (2 ** max(index - 1, 0))),
|
|
|
|
number_of_heads=number_of_head,
|
|
|
|
number_of_heads=number_of_head,
|
|
|
|
window_size=window_size,
|
|
|
|
window_size=window_size,
|
|
|
|
ff_feature_ratio=ff_feature_ratio,
|
|
|
|
ff_feature_ratio=mlp_ratio,
|
|
|
|
dropout=dropout,
|
|
|
|
dropout=drop_rate,
|
|
|
|
dropout_attention=dropout_attention,
|
|
|
|
dropout_attention=attn_drop_rate,
|
|
|
|
dropout_path=dropout_path[sum(depths[:index]):sum(depths[:index + 1])],
|
|
|
|
dropout_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
|
|
|
|
use_checkpoint=use_checkpoint,
|
|
|
|
use_checkpoint=use_checkpoint,
|
|
|
|
sequential_self_attention=sequential_self_attention,
|
|
|
|
sequential_self_attention=sequential_self_attention,
|
|
|
|
use_deformable_block=use_deformable_block and (index > 0)
|
|
|
|
use_deformable_block=use_deformable_block and (index > 0)
|
|
|
|
))
|
|
|
|
))
|
|
|
|
# Init final adaptive average pooling, and classification head
|
|
|
|
# Init final adaptive average pooling, and classification head
|
|
|
|
self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1)
|
|
|
|
self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1)
|
|
|
|
self.head: nn.Module = nn.Linear(in_features=embedding_channels * (2 ** len(depths) - 1),
|
|
|
|
self.head: nn.Module = nn.Linear(in_features=embed_dim * (2 ** len(depths) - 1),
|
|
|
|
out_features=num_classes)
|
|
|
|
out_features=num_classes)
|
|
|
|
|
|
|
|
|
|
|
|
def update_resolution(self,
|
|
|
|
def update_resolution(self,
|
|
|
|