Change parameter names to match Swin V1

pull/1150/head
Christoph Reich 3 years ago
parent f227b88831
commit 81bf0b4033

@@ -12,7 +12,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
 # Licensed under The MIT License [see LICENSE for details]
 # Written by Christoph Reich
 # --------------------------------------------------------
-from typing import Tuple, Optional, List, Union, Any
+from typing import Tuple, Optional, List, Union, Any, Type
 import torch
 import torch.nn as nn
@@ -717,6 +717,7 @@ class SwinTransformerStage(nn.Module):
                  dropout: float = 0.0,
                  dropout_attention: float = 0.0,
                  dropout_path: Union[List[float], float] = 0.0,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                  use_checkpoint: bool = False,
                  sequential_self_attention: bool = False,
                  use_deformable_block: bool = False) -> None:
@@ -791,75 +792,78 @@ class SwinTransformerV2CR(nn.Module):
           https://arxiv.org/pdf/2111.09883
     Args:
-        in_channels (int): Number of input channels
-        depth (int): Depth of the stage (number of layers)
-        downscale (bool): If true input is downsampled (see Fig. 3 or V1 paper)
-        input_resolution (Tuple[int, int]): Input resolution
-        number_of_heads (int): Number of attention heads to be utilized
-        num_classes (int): Number of output classes
-        window_size (int): Window size to be utilized
-        shift_size (int): Shifting size to be used
-        ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
-        dropout (float): Dropout in input mapping
-        dropout_attention (float): Dropout rate of attention map
-        dropout_path (float): Dropout in main path
-        use_checkpoint (bool): If true checkpointing is utilized
-        sequential_self_attention (bool): If true sequential self-attention is performed
-        use_deformable_block (bool): If true deformable block is used
+        img_size (Tuple[int, int]): Input resolution.
+        in_chans (int): Number of input channels.
+        depths (int): Depth of the stage (number of layers).
+        num_heads (int): Number of attention heads to be utilized.
+        embed_dim (int): Patch embedding dimension. Default: 96
+        num_classes (int): Number of output classes. Default: 1000
+        window_size (int): Window size to be utilized. Default: 7
+        patch_size (int | tuple(int)): Patch size. Default: 4
+        mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels. Default: 4
+        drop_rate (float): Dropout rate. Default: 0.0
+        attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
+        drop_path_rate (float): Stochastic depth rate. Default: 0.0
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
+        use_checkpoint (bool): If true checkpointing is utilized. Default: False
+        sequential_self_attention (bool): If true sequential self-attention is performed. Default: False
+        use_deformable_block (bool): If true deformable block is used. Default: False
     """
     def __init__(self,
-                 in_channels: int,
-                 embedding_channels: int,
+                 img_size: Tuple[int, int],
+                 in_chans: int,
                  depths: Tuple[int, ...],
-                 input_resolution: Tuple[int, int],
-                 number_of_heads: Tuple[int, ...],
+                 num_heads: Tuple[int, ...],
+                 embed_dim: int = 96,
                  num_classes: int = 1000,
                  window_size: int = 7,
                  patch_size: int = 4,
-                 ff_feature_ratio: int = 4,
-                 dropout: float = 0.0,
-                 dropout_attention: float = 0.0,
-                 dropout_path: float = 0.2,
+                 mlp_ratio: int = 4,
+                 drop_rate: float = 0.0,
+                 attn_drop_rate: float = 0.0,
+                 drop_path_rate: float = 0.0,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                  use_checkpoint: bool = False,
                  sequential_self_attention: bool = False,
-                 use_deformable_block: bool = False) -> None:
+                 use_deformable_block: bool = False,
+                 **kwargs: Any) -> None:
         # Call super constructor
         super(SwinTransformerV2CR, self).__init__()
         # Save parameters
         self.patch_size: int = patch_size
-        self.input_resolution: Tuple[int, int] = input_resolution
+        self.input_resolution: Tuple[int, int] = img_size
         self.window_size: int = window_size
         # Init patch embedding
-        self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_channels, out_channels=embedding_channels,
+        self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim,
                                                          patch_size=patch_size)
         # Compute patch resolution
-        patch_resolution: Tuple[int, int] = (input_resolution[0] // patch_size, input_resolution[1] // patch_size)
+        patch_resolution: Tuple[int, int] = (img_size[0] // patch_size, img_size[1] // patch_size)
         # Path dropout dependent on depth
-        dropout_path = torch.linspace(0., dropout_path, sum(depths)).tolist()
+        drop_path_rate = torch.linspace(0., drop_path_rate, sum(depths)).tolist()
         # Init stages
         self.stages: nn.ModuleList = nn.ModuleList()
-        for index, (depth, number_of_head) in enumerate(zip(depths, number_of_heads)):
+        for index, (depth, number_of_head) in enumerate(zip(depths, num_heads)):
             self.stages.append(
                 SwinTransformerStage(
-                    in_channels=embedding_channels * (2 ** max(index - 1, 0)),
+                    in_channels=embed_dim * (2 ** max(index - 1, 0)),
                     depth=depth,
                     downscale=index != 0,
                     input_resolution=(patch_resolution[0] // (2 ** max(index - 1, 0)),
                                       patch_resolution[1] // (2 ** max(index - 1, 0))),
                     number_of_heads=number_of_head,
                     window_size=window_size,
-                    ff_feature_ratio=ff_feature_ratio,
-                    dropout=dropout,
-                    dropout_attention=dropout_attention,
-                    dropout_path=dropout_path[sum(depths[:index]):sum(depths[:index + 1])],
+                    ff_feature_ratio=mlp_ratio,
+                    dropout=drop_rate,
+                    dropout_attention=attn_drop_rate,
+                    dropout_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
                     use_checkpoint=use_checkpoint,
                     sequential_self_attention=sequential_self_attention,
                     use_deformable_block=use_deformable_block and (index > 0)
                 ))
         # Init final adaptive average pooling, and classification head
         self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1)
-        self.head: nn.Module = nn.Linear(in_features=embedding_channels * (2 ** len(depths) - 1),
+        self.head: nn.Module = nn.Linear(in_features=embed_dim * (2 ** len(depths) - 1),
                                          out_features=num_classes)

     def update_resolution(self,
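For context, here is a minimal usage sketch of the renamed constructor interface. It is not part of the diff: the import path, the argument values, and the assumption that forward() returns classification logits are illustrative only.

import torch
from timm.models.swin_transformer_v2_cr import SwinTransformerV2CR  # assumed module path for this file

# Hypothetical instantiation using the Swin V1-style argument names introduced by this commit.
model = SwinTransformerV2CR(
    img_size=(224, 224),       # was: input_resolution
    in_chans=3,                # was: in_channels
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),  # was: number_of_heads
    embed_dim=96,              # was: embedding_channels
    num_classes=1000,
    window_size=7,
    patch_size=4,
    mlp_ratio=4,               # was: ff_feature_ratio
    drop_rate=0.0,             # was: dropout
    attn_drop_rate=0.0,        # was: dropout_attention
    drop_path_rate=0.2,        # was: dropout_path
)
out = model(torch.randn(1, 3, 224, 224))  # assumed to return class logits of shape (1, 1000)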

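The per-block stochastic depth schedule built in __init__ above (a torch.linspace over the total depth, sliced per stage) can be reproduced in isolation. This is a standalone sketch with assumed depths and maximum rate, not code taken from the file.

import torch

depths = (2, 2, 6, 2)    # assumed stage depths
drop_path_rate = 0.2     # assumed maximum rate

# One linearly increasing rate per block across the whole network ...
per_block = torch.linspace(0., drop_path_rate, sum(depths)).tolist()
# ... then each stage receives its contiguous slice, as in the constructor above.
per_stage = [per_block[sum(depths[:i]):sum(depths[:i + 1])] for i in range(len(depths))]

for i, rates in enumerate(per_stage):
    print(f"stage {i}: {[round(r, 3) for r in rates]}")
# Early stages get the smallest rates; the last stage approaches drop_path_rate.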
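Similarly, the per-stage channel width and input resolution derived in the constructor from embed_dim and the patch grid can be traced with a small sketch; the concrete values below are assumptions for illustration.

embed_dim = 96
img_size = (224, 224)
patch_size = 4
depths = (2, 2, 6, 2)

patch_resolution = (img_size[0] // patch_size, img_size[1] // patch_size)  # (56, 56)
for index in range(len(depths)):
    scale = 2 ** max(index - 1, 0)         # stage 0 keeps the patch grid; later stages see it halved
    stage_in_channels = embed_dim * scale  # 96, 96, 192, 384 at each stage input
    stage_resolution = (patch_resolution[0] // scale, patch_resolution[1] // scale)
    print(index, stage_in_channels, stage_resolution)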