From 90dc74c450a0ec671af0e7f73d6b4a7b5396a7af Mon Sep 17 00:00:00 2001 From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com> Date: Sat, 19 Feb 2022 22:12:11 +0100 Subject: [PATCH] Add code from https://github.com/ChristophReich1996/Swin-Transformer-V2 and change docstring style to match timm --- timm/models/swin_transformer_v2.py | 927 +++++++++++++++++++++++++++++ 1 file changed, 927 insertions(+) create mode 100644 timm/models/swin_transformer_v2.py diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py new file mode 100644 index 00000000..c146ecc8 --- /dev/null +++ b/timm/models/swin_transformer_v2.py @@ -0,0 +1,927 @@ +""" Swin Transformer V2 +A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` + - https://arxiv.org/pdf/2111.09883 + +Code adapted from https://github.com/ChristophReich1996/Swin-Transformer-V2, original copyright/license info below + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" +# -------------------------------------------------------- +# Swin Transformer V2 reimplementation +# Copyright (c) 2021 Christoph Reich +# Licensed under The MIT License [see LICENSE for details] +# Written by Christoph Reich +# -------------------------------------------------------- +from typing import Tuple, Optional, List, Union, Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from .layers import DropPath, Mlp + + +def bchw_to_bhwc(input: torch.Tensor) -> torch.Tensor: + """ Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C). + + Args: + input (torch.Tensor): Input tensor of the shape (B, C, H, W) + + Returns: + output (torch.Tensor): Permuted tensor of the shape (B, H, W, C) + """ + output: torch.Tensor = input.permute(0, 2, 3, 1) + return output + + +def bhwc_to_bchw(input: torch.Tensor) -> torch.Tensor: + """ Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W). + + Args: + input (torch.Tensor): Input tensor of the shape (B, H, W, C) + + Returns: + output (torch.Tensor): Permuted tensor of the shape (B, C, H, W) + """ + output: torch.Tensor = input.permute(0, 3, 1, 2) + return output + + +def unfold(input: torch.Tensor, + window_size: int) -> torch.Tensor: + """ Unfolds (non-overlapping) a given feature map by the given window size (stride = window size). + + Args: + input (torch.Tensor): Input feature map of the shape (B, C, H, W) + window_size (int): Window size to be applied + + Returns: + output (torch.Tensor): Unfolded tensor of the shape [B * windows, C, window size, window size] + """ + # Get original shape + _, channels, height, width = input.shape # type: int, int, int, int + # Unfold input + output: torch.Tensor = input.unfold(dimension=3, size=window_size, step=window_size) \ + .unfold(dimension=2, size=window_size, step=window_size) + # Reshape to (B * windows, C, window size, window size) + output: torch.Tensor = output.permute(0, 2, 3, 1, 5, 4).reshape(-1, channels, window_size, window_size) + return output + + +def fold(input: torch.Tensor, + window_size: int, + height: int, + width: int) -> torch.Tensor: + """ Folds a tensor of windows again to a 4D feature map. 
+
+    Args:
+        input (torch.Tensor): Input tensor of unfolded windows of the shape (B * windows, C, window size, window size)
+        window_size (int): Window size of the unfold operation
+        height (int): Height of the feature map
+        width (int): Width of the feature map
+
+    Returns:
+        output (torch.Tensor): Folded output tensor of the shape (B, C, H, W)
+    """
+    # Get channels of windows
+    channels: int = input.shape[1]
+    # Get original batch size
+    batch_size: int = int(input.shape[0] // (height * width // window_size // window_size))
+    # Reshape input to (B, C, H, W)
+    output: torch.Tensor = input.view(batch_size, height // window_size, width // window_size, channels,
+                                      window_size, window_size)
+    output: torch.Tensor = output.permute(0, 3, 1, 4, 2, 5).reshape(batch_size, channels, height, width)
+    return output
+
+
+class WindowMultiHeadAttention(nn.Module):
+    r""" This class implements window-based Multi-Head-Attention with log-spaced continuous position bias.
+
+    Args:
+        in_features (int): Number of input features
+        window_size (int): Window size
+        number_of_heads (int): Number of attention heads
+        dropout_attention (float): Dropout rate of attention map
+        dropout_projection (float): Dropout rate after projection
+        meta_network_hidden_features (int): Number of hidden features in the two layer MLP meta network
+        sequential_self_attention (bool): If true sequential self-attention is performed
+    """
+
+    def __init__(self,
+                 in_features: int,
+                 window_size: int,
+                 number_of_heads: int,
+                 dropout_attention: float = 0.,
+                 dropout_projection: float = 0.,
+                 meta_network_hidden_features: int = 256,
+                 sequential_self_attention: bool = False) -> None:
+        # Call super constructor
+        super(WindowMultiHeadAttention, self).__init__()
+        # Check parameter
+        assert (in_features % number_of_heads) == 0, \
+            "The number of input features (in_features) is not divisible by the number of heads (number_of_heads)."
+ # Save parameters + self.in_features: int = in_features + self.window_size: int = window_size + self.number_of_heads: int = number_of_heads + self.sequential_self_attention: bool = sequential_self_attention + # Init query, key and value mapping as a single layer + self.mapping_qkv: nn.Module = nn.Linear(in_features=in_features, out_features=in_features * 3, bias=True) + # Init attention dropout + self.attention_dropout: nn.Module = nn.Dropout(dropout_attention) + # Init projection mapping + self.projection: nn.Module = nn.Linear(in_features=in_features, out_features=in_features, bias=True) + # Init projection dropout + self.projection_dropout: nn.Module = nn.Dropout(dropout_projection) + # Init meta network for positional encodings + self.meta_network: nn.Module = nn.Sequential( + nn.Linear(in_features=2, out_features=meta_network_hidden_features, bias=True), + nn.ReLU(inplace=True), + nn.Linear(in_features=meta_network_hidden_features, out_features=number_of_heads, bias=True)) + # Init tau + self.register_parameter("tau", torch.nn.Parameter(torch.ones(1, number_of_heads, 1, 1))) + # Init pair-wise relative positions (log-spaced) + self.__make_pair_wise_relative_positions() + + def __make_pair_wise_relative_positions(self) -> None: + """ Method initializes the pair-wise relative positions to compute the positional biases.""" + indexes: torch.Tensor = torch.arange(self.window_size, device=self.tau.device) + coordinates: torch.Tensor = torch.stack(torch.meshgrid([indexes, indexes]), dim=0) + coordinates: torch.Tensor = torch.flatten(coordinates, start_dim=1) + relative_coordinates: torch.Tensor = coordinates[:, :, None] - coordinates[:, None, :] + relative_coordinates: torch.Tensor = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float() + relative_coordinates_log: torch.Tensor = torch.sign(relative_coordinates) \ + * torch.log(1. + relative_coordinates.abs()) + self.register_buffer("relative_coordinates_log", relative_coordinates_log) + + def update_resolution(self, + new_window_size: int, + **kwargs: Any) -> None: + """ Method updates the window size and so the pair-wise relative positions + + Args: + new_window_size (int): New window size + kwargs (Any): Unused + """ + # Set new window size + self.window_size: int = new_window_size + # Make new pair-wise relative positions + self.__make_pair_wise_relative_positions() + + def __get_relative_positional_encodings(self) -> torch.Tensor: + """ Method computes the relative positional encodings + + Returns: + relative_position_bias (torch.Tensor): Relative positional encodings + (1, number of heads, window size ** 2, window size ** 2) + """ + relative_position_bias: torch.Tensor = self.meta_network(self.relative_coordinates_log) + relative_position_bias: torch.Tensor = relative_position_bias.permute(1, 0) + relative_position_bias: torch.Tensor = relative_position_bias.reshape(self.number_of_heads, + self.window_size * self.window_size, + self.window_size * self.window_size) + relative_position_bias: torch.Tensor = relative_position_bias.unsqueeze(0) + return relative_position_bias + + def __self_attention(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + batch_size_windows: int, + tokens: int, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ This function performs standard (non-sequential) scaled cosine self-attention. 
+ + Args: + query (torch.Tensor): Query tensor of the shape [B * windows, heads, tokens, C // heads] + key (torch.Tensor): Key tensor of the shape [B * windows, heads, tokens, C // heads] + value (torch.Tensor): Value tensor of the shape (B * windows, heads, tokens, C // heads) + batch_size_windows (int): Size of the first dimension of the input tensor batch size * windows + tokens (int): Number of tokens in the input + mask (Optional[torch.Tensor]): Attention mask for the shift case + + Returns: + output (torch.Tensor): Output feature map of the shape [B * windows, tokens, C] + """ + # Compute attention map with scaled cosine attention + attention_map: torch.Tensor = torch.einsum("bhqd, bhkd -> bhqk", query, key) \ + / torch.maximum(torch.norm(query, dim=-1, keepdim=True) + * torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1), + torch.tensor(1e-06, device=query.device, dtype=query.dtype)) + attention_map: torch.Tensor = attention_map / self.tau.clamp(min=0.01) + # Apply relative positional encodings + attention_map: torch.Tensor = attention_map + self.__get_relative_positional_encodings() + # Apply mask if utilized + if mask is not None: + number_of_windows: int = mask.shape[0] + attention_map: torch.Tensor = attention_map.view(batch_size_windows // number_of_windows, number_of_windows, + self.number_of_heads, tokens, tokens) + attention_map: torch.Tensor = attention_map + mask.unsqueeze(1).unsqueeze(0) + attention_map: torch.Tensor = attention_map.view(-1, self.number_of_heads, tokens, tokens) + attention_map: torch.Tensor = attention_map.softmax(dim=-1) + # Perform attention dropout + attention_map: torch.Tensor = self.attention_dropout(attention_map) + # Apply attention map and reshape + output: torch.Tensor = torch.einsum("bhal, bhlv -> bhav", attention_map, value) + output: torch.Tensor = output.permute(0, 2, 1, 3).reshape(batch_size_windows, tokens, -1) + return output + + def __sequential_self_attention(self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + batch_size_windows: int, + tokens: int, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ This function performs sequential scaled cosine self-attention. 
+ + Args: + query (torch.Tensor): Query tensor of the shape [B * windows, heads, tokens, C // heads] + key (torch.Tensor): Key tensor of the shape [B * windows, heads, tokens, C // heads] + value (torch.Tensor): Value tensor of the shape (B * windows, heads, tokens, C // heads) + batch_size_windows (int): Size of the first dimension of the input tensor batch size * windows + tokens (int): Number of tokens in the input + mask (Optional[torch.Tensor]): Attention mask for the shift case + + Returns: + output (torch.Tensor): Output feature map of the shape [B * windows, tokens, C] + """ + # Init output tensor + output: torch.Tensor = torch.ones_like(query) + # Compute relative positional encodings fist + relative_position_bias: torch.Tensor = self.__get_relative_positional_encodings() + # Iterate over query and key tokens + for token_index_query in range(tokens): + # Compute attention map with scaled cosine attention + attention_map: torch.Tensor = \ + torch.einsum("bhd, bhkd -> bhk", query[:, :, token_index_query], key) \ + / torch.maximum(torch.norm(query[:, :, token_index_query], dim=-1, keepdim=True) + * torch.norm(key, dim=-1, keepdim=False), + torch.tensor(1e-06, device=query.device, dtype=query.dtype)) + attention_map: torch.Tensor = attention_map / self.tau.clamp(min=0.01)[..., 0] + # Apply positional encodings + attention_map: torch.Tensor = attention_map + relative_position_bias[..., token_index_query, :] + # Apply mask if utilized + if mask is not None: + number_of_windows: int = mask.shape[0] + attention_map: torch.Tensor = attention_map.view(batch_size_windows // number_of_windows, + number_of_windows, self.number_of_heads, 1, + tokens) + attention_map: torch.Tensor = attention_map \ + + mask.unsqueeze(1).unsqueeze(0)[..., token_index_query, :].unsqueeze(3) + attention_map: torch.Tensor = attention_map.view(-1, self.number_of_heads, tokens) + attention_map: torch.Tensor = attention_map.softmax(dim=-1) + # Perform attention dropout + attention_map: torch.Tensor = self.attention_dropout(attention_map) + # Apply attention map and reshape + output[:, :, token_index_query] = torch.einsum("bhl, bhlv -> bhv", attention_map, value) + output: torch.Tensor = output.permute(0, 2, 1, 3).reshape(batch_size_windows, tokens, -1) + return output + + def forward(self, + input: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """ Forward pass. 
+ + Args: + input (torch.Tensor): Input tensor of the shape (B * windows, C, H, W) + mask (Optional[torch.Tensor]): Attention mask for the shift case + + Returns: + output (torch.Tensor): Output tensor of the shape [B * windows, C, H, W] + """ + # Save original shape + batch_size_windows, channels, height, width = input.shape # type: int, int, int, int + tokens: int = height * width + # Reshape input to (B * windows, tokens (height * width), C) + input: torch.Tensor = input.reshape(batch_size_windows, channels, tokens).permute(0, 2, 1) + # Perform query, key, and value mapping + query_key_value: torch.Tensor = self.mapping_qkv(input) + query_key_value: torch.Tensor = query_key_value.view(batch_size_windows, tokens, 3, self.number_of_heads, + channels // self.number_of_heads).permute(2, 0, 3, 1, 4) + query, key, value = query_key_value[0], query_key_value[1], query_key_value[2] + # Perform attention + if self.sequential_self_attention: + output: torch.Tensor = self.__sequential_self_attention(query=query, key=key, value=value, + batch_size_windows=batch_size_windows, + tokens=tokens, + mask=mask) + else: + output: torch.Tensor = self.__self_attention(query=query, key=key, value=value, + batch_size_windows=batch_size_windows, tokens=tokens, + mask=mask) + # Perform linear mapping and dropout + output: torch.Tensor = self.projection_dropout(self.projection(output)) + # Reshape output to original shape [B * windows, C, H, W] + output: torch.Tensor = output.permute(0, 2, 1).view(batch_size_windows, channels, height, width) + return output + + +class SwinTransformerBlock(nn.Module): + r""" This class implements the Swin transformer block. + + Args: + in_channels (int): Number of input channels + input_resolution (Tuple[int, int]): Input resolution + number_of_heads (int): Number of attention heads to be utilized + window_size (int): Window size to be utilized + shift_size (int): Shifting size to be used + ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels + dropout (float): Dropout in input mapping + dropout_attention (float): Dropout rate of attention map + dropout_path (float): Dropout in main path + sequential_self_attention (bool): If true sequential self-attention is performed + """ + + def __init__(self, + in_channels: int, + input_resolution: Tuple[int, int], + number_of_heads: int, + window_size: int = 7, + shift_size: int = 0, + ff_feature_ratio: int = 4, + dropout: float = 0.0, + dropout_attention: float = 0.0, + dropout_path: float = 0.0, + sequential_self_attention: bool = False) -> None: + # Call super constructor + super(SwinTransformerBlock, self).__init__() + # Save parameters + self.in_channels: int = in_channels + self.input_resolution: Tuple[int, int] = input_resolution + # Catch case if resolution is smaller than the window size + if min(self.input_resolution) <= window_size: + self.window_size: int = min(self.input_resolution) + self.shift_size: int = 0 + self.make_windows: bool = False + else: + self.window_size: int = window_size + self.shift_size: int = shift_size + self.make_windows: bool = True + # Init normalization layers + self.normalization_1: nn.Module = nn.LayerNorm(normalized_shape=in_channels) + self.normalization_2: nn.Module = nn.LayerNorm(normalized_shape=in_channels) + # Init window attention module + self.window_attention: WindowMultiHeadAttention = WindowMultiHeadAttention( + in_features=in_channels, + window_size=self.window_size, + number_of_heads=number_of_heads, + dropout_attention=dropout_attention, + 
dropout_projection=dropout, + sequential_self_attention=sequential_self_attention) + # Init dropout layer + self.dropout: nn.Module = DropPath(drop_prob=dropout_path) if dropout_path > 0. else nn.Identity() + # Init feed-forward network + self.feed_forward_network: nn.Module = Mlp(in_features=in_channels, + hidden_features=int(in_channels * ff_feature_ratio), + drop=dropout, + out_features=in_channels) + # Make attention mask + self.__make_attention_mask() + + def __make_attention_mask(self) -> None: + """ Method generates the attention mask used in shift case. """ + # Make masks for shift case + if self.shift_size > 0: + height, width = self.input_resolution # type: int, int + mask: torch.Tensor = torch.zeros(height, width, device=self.window_attention.tau.device) + height_slices: Tuple = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + width_slices: Tuple = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + counter: int = 0 + for height_slice in height_slices: + for width_slice in width_slices: + mask[height_slice, width_slice] = counter + counter += 1 + mask_windows: torch.Tensor = unfold(mask[None, None], self.window_size) + mask_windows: torch.Tensor = mask_windows.reshape(-1, self.window_size * self.window_size) + attention_mask: Optional[torch.Tensor] = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attention_mask: Optional[torch.Tensor] = attention_mask.masked_fill(attention_mask != 0, float(-100.0)) + attention_mask: Optional[torch.Tensor] = attention_mask.masked_fill(attention_mask == 0, float(0.0)) + else: + attention_mask: Optional[torch.Tensor] = None + # Save mask + self.register_buffer("attention_mask", attention_mask) + + def update_resolution(self, + new_window_size: int, + new_input_resolution: Tuple[int, int]) -> None: + """ Method updates the image resolution to be processed and window size and so the pair-wise relative positions. + + Args: + new_window_size (int): New window size + new_input_resolution (Tuple[int, int]): New input resolution + """ + # Update input resolution + self.input_resolution: Tuple[int, int] = new_input_resolution + # Catch case if resolution is smaller than the window size + if min(self.input_resolution) <= new_window_size: + self.window_size: int = min(self.input_resolution) + self.shift_size: int = 0 + self.make_windows: bool = False + else: + self.window_size: int = new_window_size + self.shift_size: int = self.shift_size + self.make_windows: bool = True + # Update attention mask + self.__make_attention_mask() + # Update attention module + self.window_attention.update_resolution(new_window_size=new_window_size) + + def forward(self, + input: torch.Tensor) -> torch.Tensor: + """ Forward pass. 
+ + Args: + input (torch.Tensor): Input tensor of the shape [B, C, H, W] + + Returns: + output (torch.Tensor): Output tensor of the shape [B, C, H, W] + """ + # Save shape + batch_size, channels, height, width = input.shape # type: int, int, int, int + # Shift input if utilized + if self.shift_size > 0: + output_shift: torch.Tensor = torch.roll(input=input, shifts=(-self.shift_size, -self.shift_size), + dims=(-1, -2)) + else: + output_shift: torch.Tensor = input + # Make patches + output_patches: torch.Tensor = unfold(input=output_shift, window_size=self.window_size) \ + if self.make_windows else output_shift + # Perform window attention + output_attention: torch.Tensor = self.window_attention(output_patches, mask=self.attention_mask) + # Merge patches + output_merge: torch.Tensor = fold(input=output_attention, window_size=self.window_size, height=height, + width=width) if self.make_windows else output_attention + # Reverse shift if utilized + if self.shift_size > 0: + output_shift: torch.Tensor = torch.roll(input=output_merge, shifts=(self.shift_size, self.shift_size), + dims=(-1, -2)) + else: + output_shift: torch.Tensor = output_merge + # Perform normalization + output_normalize: torch.Tensor = self.normalization_1(output_shift.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + # Skip connection + output_skip: torch.Tensor = self.dropout(output_normalize) + input + # Feed forward network, normalization and skip connection + output_feed_forward: torch.Tensor = self.feed_forward_network( + output_skip.view(batch_size, channels, -1).permute(0, 2, 1)).permute(0, 2, 1) + output_feed_forward: torch.Tensor = output_feed_forward.view(batch_size, channels, height, width) + output_normalize: torch.Tensor = bhwc_to_bchw(self.normalization_2(bchw_to_bhwc(output_feed_forward))) + output: torch.Tensor = output_skip + self.dropout(output_normalize) + return output + + +class DeformableSwinTransformerBlock(SwinTransformerBlock): + r""" This class implements a deformable version of the Swin Transformer block. 
+ Inspired by: https://arxiv.org/pdf/2201.00520 + + Args: + in_channels (int): Number of input channels + input_resolution (Tuple[int, int]): Input resolution + number_of_heads (int): Number of attention heads to be utilized + window_size (int): Window size to be utilized + shift_size (int): Shifting size to be used + ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels + dropout (float): Dropout in input mapping + dropout_attention (float): Dropout rate of attention map + dropout_path (float): Dropout in main path + sequential_self_attention (bool): If true sequential self-attention is performed + offset_downscale_factor (int): Downscale factor of offset network + """ + + def __init__(self, + in_channels: int, + input_resolution: Tuple[int, int], + number_of_heads: int, + window_size: int = 7, + shift_size: int = 0, + ff_feature_ratio: int = 4, + dropout: float = 0.0, + dropout_attention: float = 0.0, + dropout_path: float = 0.0, + sequential_self_attention: bool = False, + offset_downscale_factor: int = 2) -> None: + # Call super constructor + super(DeformableSwinTransformerBlock, self).__init__( + in_channels=in_channels, + input_resolution=input_resolution, + number_of_heads=number_of_heads, + window_size=window_size, + shift_size=shift_size, + ff_feature_ratio=ff_feature_ratio, + dropout=dropout, + dropout_attention=dropout_attention, + dropout_path=dropout_path, + sequential_self_attention=sequential_self_attention + ) + # Save parameter + self.offset_downscale_factor: int = offset_downscale_factor + self.number_of_heads: int = number_of_heads + # Make default offsets + self.__make_default_offsets() + # Init offset network + self.offset_network: nn.Module = nn.Sequential( + nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=5, stride=offset_downscale_factor, + padding=3, groups=in_channels, bias=True), + nn.GELU(), + nn.Conv2d(in_channels=in_channels, out_channels=2 * self.number_of_heads, kernel_size=1, stride=1, + padding=0, bias=True) + ) + + def __make_default_offsets(self) -> None: + """ Method generates the default sampling grid (inspired by kornia) """ + # Init x and y coordinates + x: torch.Tensor = torch.linspace(0, self.input_resolution[1] - 1, self.input_resolution[1], + device=self.window_attention.tau.device) + y: torch.Tensor = torch.linspace(0, self.input_resolution[0] - 1, self.input_resolution[0], + device=self.window_attention.tau.device) + # Normalize coordinates to a range of [-1, 1] + x: torch.Tensor = (x / (self.input_resolution[1] - 1) - 0.5) * 2 + y: torch.Tensor = (y / (self.input_resolution[0] - 1) - 0.5) * 2 + # Make grid [2, height, width] + grid: torch.Tensor = torch.stack(torch.meshgrid([x, y])).transpose(1, 2) + # Reshape grid to [1, height, width, 2] + grid: torch.Tensor = grid.unsqueeze(dim=0).permute(0, 2, 3, 1) + # Register in module + self.register_buffer("default_grid", grid) + + def update_resolution(self, + new_window_size: int, + new_input_resolution: Tuple[int, int]) -> None: + """ Method updates the window size and so the pair-wise relative positions. 
+ + Args: + new_window_size (int): New window size + new_input_resolution (Tuple[int, int]): New input resolution + """ + # Update resolution and window size + super(DeformableSwinTransformerBlock, self).update_resolution(new_window_size=new_window_size, + new_input_resolution=new_input_resolution) + # Update default sampling grid + self.__make_default_offsets() + + def forward(self, + input: torch.Tensor) -> torch.Tensor: + """ Forward pass + Args: + input (torch.Tensor): Input tensor of the shape [B, C, H, W] + + Returns: + output (torch.Tensor): Output tensor of the shape [B, C, H, W] + """ + # Get input shape + batch_size, channels, height, width = input.shape + # Compute offsets of the shape [batch size, 2, height / r, width / r] + offsets: torch.Tensor = self.offset_network(input) + # Upscale offsets to the shape [batch size, 2 * number of heads, height, width] + offsets: torch.Tensor = F.interpolate(input=offsets, + size=(height, width), mode="bilinear", align_corners=True) + # Reshape offsets to [batch size, number of heads, height, width, 2] + offsets: torch.Tensor = offsets.reshape(batch_size, -1, 2, height, width).permute(0, 1, 3, 4, 2) + # Flatten batch size and number of heads and apply tanh + offsets: torch.Tensor = offsets.view(-1, height, width, 2).tanh() + # Cast offset grid to input data type + if input.dtype != self.default_grid.dtype: + self.default_grid = self.default_grid.type(input.dtype) + # Construct offset grid + offset_grid: torch.Tensor = self.default_grid.repeat_interleave(repeats=offsets.shape[0], dim=0) + offsets + # Reshape input to [batch size * number of heads, channels / number of heads, height, width] + input: torch.Tensor = input.view(batch_size, self.number_of_heads, channels // self.number_of_heads, height, + width).flatten(start_dim=0, end_dim=1) + # Apply sampling grid + input_resampled: torch.Tensor = F.grid_sample(input=input, grid=offset_grid.clip(min=-1, max=1), + mode="bilinear", align_corners=True, padding_mode="reflection") + # Reshape resampled tensor again to [batch size, channels, height, width] + input_resampled: torch.Tensor = input_resampled.view(batch_size, channels, height, width) + output: torch.Tensor = super(DeformableSwinTransformerBlock, self).forward(input=input_resampled) + return output + + +class PatchMerging(nn.Module): + """ This class implements the patch merging as a strided convolution with a normalization before. + + Args: + in_channels (int): Number of input channels + """ + + def __init__(self, + in_channels: int) -> None: + # Call super constructor + super(PatchMerging, self).__init__() + # Init normalization + self.normalization: nn.Module = nn.LayerNorm(normalized_shape=4 * in_channels) + # Init linear mapping + self.linear_mapping: nn.Module = nn.Linear(in_features=4 * in_channels, out_features=2 * in_channels, + bias=False) + + def forward(self, + input: torch.Tensor) -> torch.Tensor: + """ Forward pass. 
+ + Args: + input (torch.Tensor): Input tensor of the shape [B, C, H, W] + + Returns: + output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] + """ + # Get original shape + batch_size, channels, height, width = input.shape # type: int, int, int, int + # Reshape input to [batch size, in channels, height, width] + input: torch.Tensor = bchw_to_bhwc(input) + # Unfold input + input: torch.Tensor = input.unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2) + input: torch.Tensor = input.reshape(batch_size, input.shape[1], input.shape[2], -1) + # Normalize input + input: torch.Tensor = self.normalization(input) + # Perform linear mapping + output: torch.Tensor = bhwc_to_bchw(self.linear_mapping(input)) + return output + + +class PatchEmbedding(nn.Module): + """ Module embeds a given image into patch embeddings. + + Args: + in_channels (int): Number of input channels + out_channels (int): Number of output channels + patch_size (int): Patch size to be utilized + image_size (int): Image size to be used + """ + + def __init__(self, + in_channels: int = 3, + out_channels: int = 96, + patch_size: int = 4) -> None: + # Call super constructor + super(PatchEmbedding, self).__init__() + # Save parameters + self.out_channels: int = out_channels + # Init linear embedding as a convolution + self.linear_embedding: nn.Module = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size)) + # Init layer normalization + self.normalization: nn.Module = nn.LayerNorm(normalized_shape=out_channels) + + def forward(self, + input: torch.Tensor) -> torch.Tensor: + """ Forward pass. + + Args: + input (torch.Tensor): Input image of the shape (B, C_in, H, W) + + Returns: + embedding (torch.Tensor): Embedding of the shape (B, C_out, H / patch size, W / patch size) + """ + # Perform linear embedding + embedding: torch.Tensor = self.linear_embedding(input) + # Perform normalization + embedding: torch.Tensor = bhwc_to_bchw(self.normalization(bchw_to_bhwc(embedding))) + return embedding + + +class SwinTransformerStage(nn.Module): + r""" This class implements a stage of the Swin transformer including multiple layers. + + Args: + in_channels (int): Number of input channels + depth (int): Depth of the stage (number of layers) + downscale (bool): If true input is downsampled (see Fig. 
3 or V1 paper) + input_resolution (Tuple[int, int]): Input resolution + number_of_heads (int): Number of attention heads to be utilized + window_size (int): Window size to be utilized + shift_size (int): Shifting size to be used + ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels + dropout (float): Dropout in input mapping + dropout_attention (float): Dropout rate of attention map + dropout_path (float): Dropout in main path + use_checkpoint (bool): If true checkpointing is utilized + sequential_self_attention (bool): If true sequential self-attention is performed + use_deformable_block (bool): If true deformable block is used + """ + + def __init__(self, + in_channels: int, + depth: int, + downscale: bool, + input_resolution: Tuple[int, int], + number_of_heads: int, + window_size: int = 7, + ff_feature_ratio: int = 4, + dropout: float = 0.0, + dropout_attention: float = 0.0, + dropout_path: Union[List[float], float] = 0.0, + use_checkpoint: bool = False, + sequential_self_attention: bool = False, + use_deformable_block: bool = False) -> None: + # Call super constructor + super(SwinTransformerStage, self).__init__() + # Save parameters + self.use_checkpoint: bool = use_checkpoint + self.downscale: bool = downscale + # Init downsampling + self.downsample: nn.Module = PatchMerging(in_channels=in_channels) if downscale else nn.Identity() + # Update resolution and channels + self.input_resolution: Tuple[int, int] = (input_resolution[0] // 2, input_resolution[1] // 2) \ + if downscale else input_resolution + in_channels = in_channels * 2 if downscale else in_channels + # Get block + block = DeformableSwinTransformerBlock if use_deformable_block else SwinTransformerBlock + # Init blocks + self.blocks: nn.ModuleList = nn.ModuleList([ + block(in_channels=in_channels, + input_resolution=self.input_resolution, + number_of_heads=number_of_heads, + window_size=window_size, + shift_size=0 if ((index % 2) == 0) else window_size // 2, + ff_feature_ratio=ff_feature_ratio, + dropout=dropout, + dropout_attention=dropout_attention, + dropout_path=dropout_path[index] if isinstance(dropout_path, list) else dropout_path, + sequential_self_attention=sequential_self_attention) + for index in range(depth)]) + + def update_resolution(self, + new_window_size: int, + new_input_resolution: Tuple[int, int]) -> None: + """ Method updates the resolution to utilize and the window size and so the pair-wise relative positions. + + Args: + new_window_size (int): New window size + new_input_resolution (Tuple[int, int]): New input resolution + """ + # Update resolution + self.input_resolution: Tuple[int, int] = (new_input_resolution[0] // 2, new_input_resolution[1] // 2) \ + if self.downscale else new_input_resolution + # Update resolution of each block + for block in self.blocks: # type: SwinTransformerBlock + block.update_resolution(new_window_size=new_window_size, new_input_resolution=self.input_resolution) + + def forward(self, + input: torch.Tensor) -> torch.Tensor: + """ Forward pass. 
+
+        Args:
+            input (torch.Tensor): Input tensor of the shape [B, C, H, W]
+
+        Returns:
+            output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] if downscaling is utilized,
+                else of the shape [B, C, H, W]
+        """
+        # Downscale input tensor
+        output: torch.Tensor = self.downsample(input)
+        # Forward pass of each block
+        for block in self.blocks:  # type: nn.Module
+            # Perform checkpointing if utilized
+            if self.use_checkpoint:
+                output: torch.Tensor = checkpoint.checkpoint(block, output)
+            else:
+                output: torch.Tensor = block(output)
+        return output
+
+
+class SwinTransformerV2(nn.Module):
+    r""" Swin Transformer V2
+        A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` -
+          https://arxiv.org/pdf/2111.09883
+
+    Args:
+        in_channels (int): Number of input channels
+        embedding_channels (int): Number of channels of the patch embedding
+        depths (Tuple[int, ...]): Depth (number of layers) of each stage
+        input_resolution (Tuple[int, int]): Input resolution
+        number_of_heads (Tuple[int, ...]): Number of attention heads to be utilized in each stage
+        num_classes (int): Number of output classes
+        window_size (int): Window size to be utilized
+        patch_size (int): Patch size to be utilized
+        ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
+        dropout (float): Dropout in input mapping
+        dropout_attention (float): Dropout rate of attention map
+        dropout_path (float): Dropout in main path
+        use_checkpoint (bool): If true checkpointing is utilized
+        sequential_self_attention (bool): If true sequential self-attention is performed
+        use_deformable_block (bool): If true deformable block is used
+    """
+
+    def __init__(self,
+                 in_channels: int,
+                 embedding_channels: int,
+                 depths: Tuple[int, ...],
+                 input_resolution: Tuple[int, int],
+                 number_of_heads: Tuple[int, ...],
+                 num_classes: int = 1000,
+                 window_size: int = 7,
+                 patch_size: int = 4,
+                 ff_feature_ratio: int = 4,
+                 dropout: float = 0.0,
+                 dropout_attention: float = 0.0,
+                 dropout_path: float = 0.2,
+                 use_checkpoint: bool = False,
+                 sequential_self_attention: bool = False,
+                 use_deformable_block: bool = False) -> None:
+        # Call super constructor
+        super(SwinTransformerV2, self).__init__()
+        # Save parameters
+        self.patch_size: int = patch_size
+        self.input_resolution: Tuple[int, int] = input_resolution
+        self.window_size: int = window_size
+        # Init patch embedding
+        self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_channels, out_channels=embedding_channels,
+                                                         patch_size=patch_size)
+        # Compute patch resolution
+        patch_resolution: Tuple[int, int] = (input_resolution[0] // patch_size, input_resolution[1] // patch_size)
+        # Path dropout dependent on depth
+        dropout_path = torch.linspace(0., dropout_path, sum(depths)).tolist()
+        # Init stages
+        self.stages: nn.ModuleList = nn.ModuleList()
+        for index, (depth, number_of_head) in enumerate(zip(depths, number_of_heads)):
+            self.stages.append(
+                SwinTransformerStage(
+                    in_channels=embedding_channels * (2 ** max(index - 1, 0)),
+                    depth=depth,
+                    downscale=index != 0,
+                    input_resolution=(patch_resolution[0] // (2 ** max(index - 1, 0)),
+                                      patch_resolution[1] // (2 ** max(index - 1, 0))),
+                    number_of_heads=number_of_head,
+                    window_size=window_size,
+                    ff_feature_ratio=ff_feature_ratio,
+                    dropout=dropout,
+                    dropout_attention=dropout_attention,
+                    dropout_path=dropout_path[sum(depths[:index]):sum(depths[:index + 1])],
+                    use_checkpoint=use_checkpoint,
+                    sequential_self_attention=sequential_self_attention,
+                    use_deformable_block=use_deformable_block and (index > 0)
+                ))
+        # Init final adaptive average pooling and classification head
+        self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1)
+        self.head: nn.Module = nn.Linear(in_features=embedding_channels * (2 ** (len(depths) - 1)),
+                                         out_features=num_classes)
+
+    def update_resolution(self,
+                          new_input_resolution: Optional[Tuple[int, int]] = None,
+                          new_window_size: Optional[int] = None) -> None:
+        """ Method updates the image resolution to be processed, the window size, and thus the pair-wise relative positions.
+
+        Args:
+            new_window_size (Optional[int]): New window size, if None the current window size is used
+            new_input_resolution (Optional[Tuple[int, int]]): New input resolution, if None the current resolution is used
+        """
+        # Check parameters
+        if new_input_resolution is None:
+            new_input_resolution = self.input_resolution
+        if new_window_size is None:
+            new_window_size = self.window_size
+        # Compute new patch resolution
+        new_patch_resolution: Tuple[int, int] = (new_input_resolution[0] // self.patch_size,
+                                                 new_input_resolution[1] // self.patch_size)
+        # Update resolution of each stage
+        for index, stage in enumerate(self.stages):  # type: int, SwinTransformerStage
+            stage.update_resolution(new_window_size=new_window_size,
+                                    new_input_resolution=(new_patch_resolution[0] // (2 ** max(index - 1, 0)),
+                                                          new_patch_resolution[1] // (2 ** max(index - 1, 0))))
+
+    def forward_features(self,
+                         input: torch.Tensor) -> List[torch.Tensor]:
+        """ Forward pass to extract feature maps of each stage.
+
+        Args:
+            input (torch.Tensor): Input images of the shape (B, C, H, W)
+
+        Returns:
+            features (List[torch.Tensor]): List of feature maps from each stage
+        """
+        # Perform patch embedding
+        output: torch.Tensor = self.patch_embedding(input)
+        # Init list to store features
+        features: List[torch.Tensor] = []
+        # Forward pass of each stage
+        for stage in self.stages:
+            output: torch.Tensor = stage(output)
+            features.append(output)
+        return features
+
+    def forward(self,
+                input: torch.Tensor) -> torch.Tensor:
+        """ Forward pass.
+
+        Args:
+            input (torch.Tensor): Input images of the shape (B, C, H, W)
+
+        Returns:
+            classification (torch.Tensor): Classification of the shape (B, num_classes)
+        """
+        # Perform patch embedding
+        output: torch.Tensor = self.patch_embedding(input)
+        # Forward pass of each stage
+        for stage in self.stages:
+            output: torch.Tensor = stage(output)
+        # Perform average pooling and flatten features to the shape (B, C)
+        output: torch.Tensor = self.average_pool(output).flatten(start_dim=1)
+        # Predict classification
+        classification: torch.Tensor = self.head(output)
+        return classification
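
For reference, a minimal usage sketch of the module added by this patch (not part of the diff). It assumes the file is importable directly from timm.models.swin_transformer_v2 as created above; the Swin-T-style hyperparameters (embedding_channels=96, depths=(2, 2, 6, 2), number_of_heads=(3, 6, 12, 24)) are illustrative values, not defaults defined in the patch.

import torch

from timm.models.swin_transformer_v2 import SwinTransformerV2

# Build a model; constructor arguments mirror the signature added in this patch.
# The concrete values below are assumed, Swin-T-like settings for illustration.
model = SwinTransformerV2(
    in_channels=3,
    embedding_channels=96,
    depths=(2, 2, 6, 2),
    input_resolution=(224, 224),
    number_of_heads=(3, 6, 12, 24),
    num_classes=1000,
    window_size=7,
    patch_size=4,
)

# Classification forward pass: (B, 3, 224, 224) -> (B, num_classes)
images = torch.randn(2, 3, 224, 224)
logits = model(images)

# Multi-scale feature maps of each stage, e.g. for dense prediction heads
features = model.forward_features(images)

# Adapt the model to a new input resolution / window size after construction
model.update_resolution(new_input_resolution=(384, 384), new_window_size=12)
logits_384 = model(torch.randn(2, 3, 384, 384))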