Change parameter names to match Swin V1

Branch: pull/1150/head
Author: Christoph Reich, 3 years ago
Parent: f227b88831
Commit: 81bf0b4033
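
This commit renames the public arguments of the Swin Transformer V2 (CR) classes so they line up with the names used by Swin V1 and the rest of timm. As an informal summary of the renames visible in the hunks below, the old-to-new mapping can be written as a plain dictionary; the remap_kwargs helper is purely illustrative and is not part of the commit:

# Illustrative only: summarizes the renames in this commit, not part of timm.
OLD_TO_NEW = {
    "input_resolution": "img_size",
    "in_channels": "in_chans",
    "embedding_channels": "embed_dim",
    "number_of_heads": "num_heads",
    "ff_feature_ratio": "mlp_ratio",
    "dropout": "drop_rate",
    "dropout_attention": "attn_drop_rate",
    "dropout_path": "drop_path_rate",
}

def remap_kwargs(old_kwargs: dict) -> dict:
    """Translate old-style keyword arguments into the new Swin V1 style names."""
    return {OLD_TO_NEW.get(key, key): value for key, value in old_kwargs.items()}

print(remap_kwargs({"embedding_channels": 96, "dropout_path": 0.2}))
# -> {'embed_dim': 96, 'drop_path_rate': 0.2}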

@@ -12,7 +12,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
 # Licensed under The MIT License [see LICENSE for details]
 # Written by Christoph Reich
 # --------------------------------------------------------
-from typing import Tuple, Optional, List, Union, Any
+from typing import Tuple, Optional, List, Union, Any, Type
 import torch
 import torch.nn as nn
@@ -717,6 +717,7 @@ class SwinTransformerStage(nn.Module):
                  dropout: float = 0.0,
                  dropout_attention: float = 0.0,
                  dropout_path: Union[List[float], float] = 0.0,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                  use_checkpoint: bool = False,
                  sequential_self_attention: bool = False,
                  use_deformable_block: bool = False) -> None:
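
The hunk above threads a new norm_layer argument of type Type[nn.Module] through SwinTransformerStage, which is why the Type import was added in the first hunk. Below is a minimal sketch of this pattern using a hypothetical SimpleBlock, not the actual timm code: the caller passes the normalization class itself, and the module instantiates it internally.

from typing import Type
import torch
import torch.nn as nn

class SimpleBlock(nn.Module):
    """Hypothetical block showing the norm_layer pattern; not the timm implementation."""

    def __init__(self, channels: int, norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:
        super().__init__()
        # The class (not an instance) is passed in and instantiated here.
        self.norm = norm_layer(channels)
        self.linear = nn.Linear(channels, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.norm(x))

block = SimpleBlock(channels=96)        # defaults to nn.LayerNorm
out = block(torch.rand(2, 49, 96))      # (batch, tokens, channels)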
@@ -791,75 +792,78 @@ class SwinTransformerV2CR(nn.Module):
         https://arxiv.org/pdf/2111.09883
     Args:
-        in_channels (int): Number of input channels
-        depth (int): Depth of the stage (number of layers)
-        downscale (bool): If true input is downsampled (see Fig. 3 or V1 paper)
-        input_resolution (Tuple[int, int]): Input resolution
-        number_of_heads (int): Number of attention heads to be utilized
-        num_classes (int): Number of output classes
-        window_size (int): Window size to be utilized
-        shift_size (int): Shifting size to be used
-        ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
-        dropout (float): Dropout in input mapping
-        dropout_attention (float): Dropout rate of attention map
-        dropout_path (float): Dropout in main path
-        use_checkpoint (bool): If true checkpointing is utilized
-        sequential_self_attention (bool): If true sequential self-attention is performed
-        use_deformable_block (bool): If true deformable block is used
+        img_size (Tuple[int, int]): Input resolution.
+        in_chans (int): Number of input channels.
+        depths (int): Depth of the stage (number of layers).
+        num_heads (int): Number of attention heads to be utilized.
+        embed_dim (int): Patch embedding dimension. Default: 96
+        num_classes (int): Number of output classes. Default: 1000
+        window_size (int): Window size to be utilized. Default: 7
+        patch_size (int | tuple(int)): Patch size. Default: 4
+        mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels. Default: 4
+        drop_rate (float): Dropout rate. Default: 0.0
+        attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
+        drop_path_rate (float): Stochastic depth rate. Default: 0.0
+        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
+        use_checkpoint (bool): If true checkpointing is utilized. Default: False
+        sequential_self_attention (bool): If true sequential self-attention is performed. Default: False
+        use_deformable_block (bool): If true deformable block is used. Default: False
     """
     def __init__(self,
-                 in_channels: int,
-                 embedding_channels: int,
+                 img_size: Tuple[int, int],
+                 in_chans: int,
                  depths: Tuple[int, ...],
-                 input_resolution: Tuple[int, int],
-                 number_of_heads: Tuple[int, ...],
+                 num_heads: Tuple[int, ...],
+                 embed_dim: int = 96,
                  num_classes: int = 1000,
                  window_size: int = 7,
                  patch_size: int = 4,
-                 ff_feature_ratio: int = 4,
-                 dropout: float = 0.0,
-                 dropout_attention: float = 0.0,
-                 dropout_path: float = 0.2,
+                 mlp_ratio: int = 4,
+                 drop_rate: float = 0.0,
+                 attn_drop_rate: float = 0.0,
+                 drop_path_rate: float = 0.0,
+                 norm_layer: Type[nn.Module] = nn.LayerNorm,
                  use_checkpoint: bool = False,
                  sequential_self_attention: bool = False,
-                 use_deformable_block: bool = False) -> None:
+                 use_deformable_block: bool = False,
+                 **kwargs: Any) -> None:
         # Call super constructor
         super(SwinTransformerV2CR, self).__init__()
         # Save parameters
         self.patch_size: int = patch_size
-        self.input_resolution: Tuple[int, int] = input_resolution
+        self.input_resolution: Tuple[int, int] = img_size
         self.window_size: int = window_size
         # Init patch embedding
-        self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_channels, out_channels=embedding_channels,
+        self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim,
                                                          patch_size=patch_size)
         # Compute patch resolution
-        patch_resolution: Tuple[int, int] = (input_resolution[0] // patch_size, input_resolution[1] // patch_size)
+        patch_resolution: Tuple[int, int] = (img_size[0] // patch_size, img_size[1] // patch_size)
         # Path dropout dependent on depth
-        dropout_path = torch.linspace(0., dropout_path, sum(depths)).tolist()
+        drop_path_rate = torch.linspace(0., drop_path_rate, sum(depths)).tolist()
         # Init stages
         self.stages: nn.ModuleList = nn.ModuleList()
-        for index, (depth, number_of_head) in enumerate(zip(depths, number_of_heads)):
+        for index, (depth, number_of_head) in enumerate(zip(depths, num_heads)):
             self.stages.append(
                 SwinTransformerStage(
-                    in_channels=embedding_channels * (2 ** max(index - 1, 0)),
+                    in_channels=embed_dim * (2 ** max(index - 1, 0)),
                     depth=depth,
                     downscale=index != 0,
                     input_resolution=(patch_resolution[0] // (2 ** max(index - 1, 0)),
                                       patch_resolution[1] // (2 ** max(index - 1, 0))),
                     number_of_heads=number_of_head,
                     window_size=window_size,
-                    ff_feature_ratio=ff_feature_ratio,
-                    dropout=dropout,
-                    dropout_attention=dropout_attention,
-                    dropout_path=dropout_path[sum(depths[:index]):sum(depths[:index + 1])],
+                    ff_feature_ratio=mlp_ratio,
+                    dropout=drop_rate,
+                    dropout_attention=attn_drop_rate,
+                    dropout_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
                     use_checkpoint=use_checkpoint,
                     sequential_self_attention=sequential_self_attention,
                     use_deformable_block=use_deformable_block and (index > 0)
                 ))
         # Init final adaptive average pooling, and classification head
         self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1)
-        self.head: nn.Module = nn.Linear(in_features=embedding_channels * (2 ** len(depths) - 1),
+        self.head: nn.Module = nn.Linear(in_features=embed_dim * (2 ** len(depths) - 1),
                                          out_features=num_classes)
     def update_resolution(self,
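
After this rename, the model is constructed with the Swin V1 style argument names. The sketch below is a usage example under two assumptions not shown in this diff: the import path and the forward interface. It passes drop_path_rate as a single float, which the constructor spreads over the blocks via torch.linspace(0., drop_path_rate, sum(depths)).

import torch
from timm.models.swin_transformer_v2_cr import SwinTransformerV2CR  # assumed import path

model = SwinTransformerV2CR(
    img_size=(224, 224),        # was: input_resolution
    in_chans=3,                 # was: in_channels
    embed_dim=96,               # was: embedding_channels
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),   # was: number_of_heads
    mlp_ratio=4,                # was: ff_feature_ratio
    drop_rate=0.0,              # was: dropout
    attn_drop_rate=0.0,         # was: dropout_attention
    drop_path_rate=0.1,         # was: dropout_path
    num_classes=1000,
)
logits = model(torch.rand(1, 3, 224, 224))  # assumed: forward takes an image batch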
