You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/models/swin_transformer_v2.py

928 lines
46 KiB

""" Swin Transformer V2
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
- https://arxiv.org/pdf/2111.09883
Code adapted from https://github.com/ChristophReich1996/Swin-Transformer-V2, original copyright/license info below
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# --------------------------------------------------------
# Swin Transformer V2 reimplementation
# Copyright (c) 2021 Christoph Reich
# Licensed under The MIT License [see LICENSE for details]
# Written by Christoph Reich
# --------------------------------------------------------
from typing import Tuple, Optional, List, Union, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from .layers import DropPath, Mlp
def bchw_to_bhwc(input: torch.Tensor) -> torch.Tensor:
    """ Move the channel axis of a 4D tensor from position 1 to the last position.

    Args:
        input (torch.Tensor): Tensor laid out as (B, C, H, W)
    Returns:
        output (torch.Tensor): Permuted view of ``input`` laid out as (B, H, W, C)
    """
    batch_axis, channel_axis, height_axis, width_axis = 0, 1, 2, 3
    return input.permute(batch_axis, height_axis, width_axis, channel_axis)
def bhwc_to_bchw(input: torch.Tensor) -> torch.Tensor:
    """ Move the channel axis of a 4D tensor from the last position back to position 1.

    Args:
        input (torch.Tensor): Tensor laid out as (B, H, W, C)
    Returns:
        output (torch.Tensor): Permuted view of ``input`` laid out as (B, C, H, W)
    """
    batch_axis, height_axis, width_axis, channel_axis = 0, 1, 2, 3
    return input.permute(batch_axis, channel_axis, height_axis, width_axis)
def unfold(input: torch.Tensor,
           window_size: int) -> torch.Tensor:
    """ Split a feature map into non-overlapping square windows (stride = window size).

    Rows/columns that do not fill a complete window are dropped (``Tensor.unfold`` semantics).

    Args:
        input (torch.Tensor): Input feature map of the shape (B, C, H, W)
        window_size (int): Side length of each square window
    Returns:
        output (torch.Tensor): Windows of the shape [B * windows, C, window size, window size], where the
            windows of each image are ordered row-major over the window grid
    """
    channels = input.shape[1]
    # Slide over the width axis first, then over the height axis:
    # (B, C, H, W) -> (B, C, H // ws, W // ws, ws_w, ws_h)
    windows = input.unfold(3, window_size, window_size)
    windows = windows.unfold(2, window_size, window_size)
    # -> (B, H // ws, W // ws, C, ws_h, ws_w), then merge the batch and window-grid axes
    windows = windows.permute(0, 2, 3, 1, 5, 4)
    return windows.reshape(-1, channels, window_size, window_size)
def fold(input: torch.Tensor,
         window_size: int,
         height: int,
         width: int) -> torch.Tensor:
    """ Folds a batch of non-overlapping windows back into a 4D feature map (inverse of ``unfold``).

    Args:
        input (torch.Tensor): Windows of the shape (B * windows, C, window size, window size); the windows
            of each image are expected in row-major order over the window grid, as produced by ``unfold``
        window_size (int): Window size of the unfold operation
        height (int): Height of the feature map
        width (int): Width of the feature map
    Returns:
        output (torch.Tensor): Folded output tensor of the shape (B, C, H, W)
    Raises:
        ValueError: If height or width is not divisible by window_size, or if the number of windows is not
            a multiple of the number of windows per image
    """
    # A feature map can only be reassembled exactly from full windows
    if height % window_size != 0 or width % window_size != 0:
        raise ValueError(f"Feature map size ({height}, {width}) is not divisible by the window size "
                         f"{window_size}.")
    windows_height: int = height // window_size
    windows_width: int = width // window_size
    windows_per_image: int = windows_height * windows_width
    if input.shape[0] % windows_per_image != 0:
        raise ValueError(f"Number of windows ({input.shape[0]}) is not a multiple of the windows per image "
                         f"({windows_per_image}).")
    channels: int = input.shape[1]
    batch_size: int = input.shape[0] // windows_per_image
    # (B * nH * nW, C, ws, ws) -> (B, nH, nW, C, ws, ws)
    output: torch.Tensor = input.reshape(batch_size, windows_height, windows_width, channels,
                                         window_size, window_size)
    # -> (B, C, nH, ws, nW, ws) -> (B, C, H, W)
    output = output.permute(0, 3, 1, 4, 2, 5).reshape(batch_size, channels, height, width)
    return output
class WindowMultiHeadAttention(nn.Module):
    r""" Window-based multi-head self-attention with a log-spaced continuous position bias (Swin V2).

    Attention logits are cosine similarities between queries and keys, divided by a learnable per-head
    temperature ``tau`` (clamped to at least 0.01), plus a relative position bias that a small MLP
    ("meta network") predicts from log-spaced relative coordinates.

    Args:
        in_features (int): Number of input features
        window_size (int): Window size (side length of a square window)
        number_of_heads (int): Number of attention heads; must divide in_features
        dropout_attention (float): Dropout rate of attention map
        dropout_projection (float): Dropout rate after projection
        meta_network_hidden_features (int): Number of hidden features in the two layer MLP meta network
        sequential_self_attention (bool): If true, attention is computed one query token at a time
    """

    def __init__(self,
                 in_features: int,
                 window_size: int,
                 number_of_heads: int,
                 dropout_attention: float = 0.,
                 dropout_projection: float = 0.,
                 meta_network_hidden_features: int = 256,
                 sequential_self_attention: bool = False) -> None:
        super(WindowMultiHeadAttention, self).__init__()
        assert (in_features % number_of_heads) == 0, \
            "The number of input features (in_features) are not divisible by the number of heads (number_of_heads)."
        # Hyper-parameters
        self.in_features: int = in_features
        self.window_size: int = window_size
        self.number_of_heads: int = number_of_heads
        self.sequential_self_attention: bool = sequential_self_attention
        # Fused query/key/value projection, attention dropout, output projection and its dropout.
        # NOTE: construction order matters for reproducibility, since each Linear consumes the RNG on init.
        self.mapping_qkv: nn.Module = nn.Linear(in_features, in_features * 3, bias=True)
        self.attention_dropout: nn.Module = nn.Dropout(dropout_attention)
        self.projection: nn.Module = nn.Linear(in_features, in_features, bias=True)
        self.projection_dropout: nn.Module = nn.Dropout(dropout_projection)
        # Two-layer MLP mapping a 2D log-spaced relative offset to one bias value per head
        self.meta_network: nn.Module = nn.Sequential(
            nn.Linear(2, meta_network_hidden_features, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(meta_network_hidden_features, number_of_heads, bias=True))
        # Learnable per-head temperature of the cosine attention
        self.tau: nn.Parameter = nn.Parameter(torch.ones(1, number_of_heads, 1, 1))
        # Buffer with the log-spaced pair-wise relative coordinates of the current window size
        self.__make_pair_wise_relative_positions()

    def __make_pair_wise_relative_positions(self) -> None:
        """ (Re)build the ``relative_coordinates_log`` buffer of shape (window size ** 4, 2)."""
        positions: torch.Tensor = torch.arange(self.window_size, device=self.tau.device)
        # (2, ws, ws) grid of (row, column) coordinates, flattened to (2, ws * ws)
        grid: torch.Tensor = torch.flatten(torch.stack(torch.meshgrid([positions, positions]), dim=0),
                                           start_dim=1)
        # All pair-wise coordinate differences: (2, ws², ws²) -> (ws² * ws², 2)
        deltas: torch.Tensor = (grid[:, :, None] - grid[:, None, :]).permute(1, 2, 0).reshape(-1, 2).float()
        # Log-spacing: sign(d) * log(1 + |d|)
        log_deltas: torch.Tensor = torch.sign(deltas) * torch.log(1. + deltas.abs())
        self.register_buffer("relative_coordinates_log", log_deltas)

    def update_resolution(self,
                          new_window_size: int,
                          **kwargs: Any) -> None:
        """ Switch to a new window size and rebuild the pair-wise relative positions.

        Args:
            new_window_size (int): New window size
            kwargs (Any): Ignored
        """
        self.window_size: int = new_window_size
        self.__make_pair_wise_relative_positions()

    def __get_relative_positional_encodings(self) -> torch.Tensor:
        """ Evaluate the meta network on all relative coordinates.

        Returns:
            relative_position_bias (torch.Tensor): Bias of the shape
                (1, number of heads, window size ** 2, window size ** 2)
        """
        tokens: int = self.window_size * self.window_size
        # (ws⁴, heads) -> (heads, ws⁴) -> (heads, ws², ws²) -> (1, heads, ws², ws²)
        per_pair: torch.Tensor = self.meta_network(self.relative_coordinates_log)
        return per_pair.permute(1, 0).reshape(self.number_of_heads, tokens, tokens).unsqueeze(0)

    def __self_attention(self,
                         query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         batch_size_windows: int,
                         tokens: int,
                         mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """ Scaled cosine self-attention over all query tokens at once.

        Args:
            query (torch.Tensor): Query tensor of the shape [B * windows, heads, tokens, C // heads]
            key (torch.Tensor): Key tensor of the shape [B * windows, heads, tokens, C // heads]
            value (torch.Tensor): Value tensor of the shape [B * windows, heads, tokens, C // heads]
            batch_size_windows (int): Size of the first input dimension (batch size * windows)
            tokens (int): Number of tokens per window
            mask (Optional[torch.Tensor]): Additive attention mask of the shape [windows, tokens, tokens]
        Returns:
            output (torch.Tensor): Output feature map of the shape [B * windows, tokens, C]
        """
        epsilon: torch.Tensor = torch.tensor(1e-06, device=query.device, dtype=query.dtype)
        # Cosine similarity: q·k / max(|q| |k|, eps)
        similarity: torch.Tensor = torch.einsum("bhqd, bhkd -> bhqk", query, key)
        norms: torch.Tensor = torch.norm(query, dim=-1, keepdim=True) \
            * torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1)
        logits: torch.Tensor = similarity / torch.maximum(norms, epsilon)
        logits = logits / self.tau.clamp(min=0.01)
        logits = logits + self.__get_relative_positional_encodings()
        if mask is not None:
            windows: int = mask.shape[0]
            # Split (B * windows) so the per-window mask broadcasts over batch and heads
            logits = logits.view(batch_size_windows // windows, windows, self.number_of_heads, tokens, tokens)
            logits = logits + mask[None, :, None]
            logits = logits.view(-1, self.number_of_heads, tokens, tokens)
        weights: torch.Tensor = self.attention_dropout(logits.softmax(dim=-1))
        context: torch.Tensor = torch.einsum("bhal, bhlv -> bhav", weights, value)
        # (B * windows, heads, tokens, C // heads) -> (B * windows, tokens, C)
        return context.permute(0, 2, 1, 3).reshape(batch_size_windows, tokens, -1)

    def __sequential_self_attention(self,
                                    query: torch.Tensor,
                                    key: torch.Tensor,
                                    value: torch.Tensor,
                                    batch_size_windows: int,
                                    tokens: int,
                                    mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """ Scaled cosine self-attention evaluated one query token at a time.

        This avoids building the whole per-window attention map in a single step.

        Args:
            query (torch.Tensor): Query tensor of the shape [B * windows, heads, tokens, C // heads]
            key (torch.Tensor): Key tensor of the shape [B * windows, heads, tokens, C // heads]
            value (torch.Tensor): Value tensor of the shape [B * windows, heads, tokens, C // heads]
            batch_size_windows (int): Size of the first input dimension (batch size * windows)
            tokens (int): Number of tokens per window
            mask (Optional[torch.Tensor]): Additive attention mask of the shape [windows, tokens, tokens]
        Returns:
            output (torch.Tensor): Output feature map of the shape [B * windows, tokens, C]
        """
        output: torch.Tensor = torch.ones_like(query)
        bias: torch.Tensor = self.__get_relative_positional_encodings()
        epsilon: torch.Tensor = torch.tensor(1e-06, device=query.device, dtype=query.dtype)
        # (windows, tokens, tokens) -> (1, windows, 1, tokens, tokens)
        expanded_mask: Optional[torch.Tensor] = None if mask is None else mask[None, :, None]
        for index in range(tokens):
            single_query: torch.Tensor = query[:, :, index]
            similarity: torch.Tensor = torch.einsum("bhd, bhkd -> bhk", single_query, key)
            norms: torch.Tensor = torch.norm(single_query, dim=-1, keepdim=True) \
                * torch.norm(key, dim=-1, keepdim=False)
            scores: torch.Tensor = similarity / torch.maximum(norms, epsilon)
            scores = scores / self.tau.clamp(min=0.01)[..., 0]
            scores = scores + bias[..., index, :]
            if expanded_mask is not None:
                windows: int = expanded_mask.shape[1]
                scores = scores.view(batch_size_windows // windows, windows, self.number_of_heads, 1, tokens)
                scores = scores + expanded_mask[..., index, :].unsqueeze(3)
                scores = scores.view(-1, self.number_of_heads, tokens)
            weights: torch.Tensor = self.attention_dropout(scores.softmax(dim=-1))
            output[:, :, index] = torch.einsum("bhl, bhlv -> bhv", weights, value)
        return output.permute(0, 2, 1, 3).reshape(batch_size_windows, tokens, -1)

    def forward(self,
                input: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """ Forward pass.

        Args:
            input (torch.Tensor): Input tensor of the shape (B * windows, C, H, W)
            mask (Optional[torch.Tensor]): Additive attention mask for the shifted-window case
        Returns:
            output (torch.Tensor): Output tensor of the shape [B * windows, C, H, W]
        """
        batch_size_windows, channels, height, width = input.shape
        tokens: int = height * width
        head_features: int = channels // self.number_of_heads
        # (B * windows, C, H, W) -> (B * windows, tokens, C)
        sequence: torch.Tensor = input.reshape(batch_size_windows, channels, tokens).permute(0, 2, 1)
        # (B * windows, tokens, 3 * C) -> 3 x (B * windows, heads, tokens, C // heads)
        qkv: torch.Tensor = self.mapping_qkv(sequence).view(batch_size_windows, tokens, 3,
                                                            self.number_of_heads, head_features)
        query, key, value = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        attend = self.__sequential_self_attention if self.sequential_self_attention else self.__self_attention
        attended: torch.Tensor = attend(query=query, key=key, value=value,
                                        batch_size_windows=batch_size_windows, tokens=tokens, mask=mask)
        projected: torch.Tensor = self.projection_dropout(self.projection(attended))
        # (B * windows, tokens, C) -> (B * windows, C, H, W)
        return projected.permute(0, 2, 1).view(batch_size_windows, channels, height, width)
class SwinTransformerBlock(nn.Module):
    r""" This class implements the Swin Transformer V2 block (residual post-normalization).

    The block computes ``x = x + drop_path(norm_1(window_attention(x)))`` followed by
    ``x = x + drop_path(norm_2(ffn(x)))``. When ``shift_size > 0`` the window attention runs on a cyclically
    shifted feature map with an attention mask separating the regions that the shift wrapped around.

    Args:
        in_channels (int): Number of input channels
        input_resolution (Tuple[int, int]): Input resolution (height, width)
        number_of_heads (int): Number of attention heads to be utilized
        window_size (int): Window size to be utilized
        shift_size (int): Shifting size to be used
        ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
        dropout (float): Dropout rate passed to the attention output projection and to the FFN (Mlp ``drop``)
        dropout_attention (float): Dropout rate of attention map
        dropout_path (float): Drop path rate of both residual branches
        sequential_self_attention (bool): If true sequential self-attention is performed
    """

    def __init__(self,
                 in_channels: int,
                 input_resolution: Tuple[int, int],
                 number_of_heads: int,
                 window_size: int = 7,
                 shift_size: int = 0,
                 ff_feature_ratio: int = 4,
                 dropout: float = 0.0,
                 dropout_attention: float = 0.0,
                 dropout_path: float = 0.0,
                 sequential_self_attention: bool = False) -> None:
        # Call super constructor
        super(SwinTransformerBlock, self).__init__()
        # Save parameters
        self.in_channels: int = in_channels
        self.input_resolution: Tuple[int, int] = input_resolution
        # Remember the requested shift so update_resolution can restore it after a resolution that was too
        # small for windowing forced it to zero
        self._configured_shift_size: int = shift_size
        # Catch case if resolution is smaller than the window size: attend over the whole map instead
        if min(self.input_resolution) <= window_size:
            self.window_size: int = min(self.input_resolution)
            self.shift_size: int = 0
            self.make_windows: bool = False
        else:
            self.window_size: int = window_size
            self.shift_size: int = shift_size
            self.make_windows: bool = True
        # Init normalization layers
        self.normalization_1: nn.Module = nn.LayerNorm(normalized_shape=in_channels)
        self.normalization_2: nn.Module = nn.LayerNorm(normalized_shape=in_channels)
        # Init window attention module (uses the effective, possibly clamped, window size)
        self.window_attention: WindowMultiHeadAttention = WindowMultiHeadAttention(
            in_features=in_channels,
            window_size=self.window_size,
            number_of_heads=number_of_heads,
            dropout_attention=dropout_attention,
            dropout_projection=dropout,
            sequential_self_attention=sequential_self_attention)
        # Init drop path layer (shared by both residual branches)
        self.dropout: nn.Module = DropPath(drop_prob=dropout_path) if dropout_path > 0. else nn.Identity()
        # Init feed-forward network
        self.feed_forward_network: nn.Module = Mlp(in_features=in_channels,
                                                   hidden_features=int(in_channels * ff_feature_ratio),
                                                   drop=dropout,
                                                   out_features=in_channels)
        # Make attention mask
        self.__make_attention_mask()

    def __make_attention_mask(self) -> None:
        """ Method (re)generates the ``attention_mask`` buffer used in the shift case (None otherwise). """
        if self.shift_size > 0:
            height, width = self.input_resolution  # type: int, int
            # Label each of the 3 x 3 regions created by the cyclic shift with a distinct id
            mask: torch.Tensor = torch.zeros(height, width, device=self.window_attention.tau.device)
            height_slices: Tuple = (slice(0, -self.window_size),
                                    slice(-self.window_size, -self.shift_size),
                                    slice(-self.shift_size, None))
            width_slices: Tuple = (slice(0, -self.window_size),
                                   slice(-self.window_size, -self.shift_size),
                                   slice(-self.shift_size, None))
            counter: int = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    mask[height_slice, width_slice] = counter
                    counter += 1
            mask_windows: torch.Tensor = unfold(mask[None, None], self.window_size)
            mask_windows = mask_windows.reshape(-1, self.window_size * self.window_size)
            # Token pairs from different regions get -100, which is added to their attention logits
            attention_mask: Optional[torch.Tensor] = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attention_mask = attention_mask.masked_fill(attention_mask != 0, float(-100.0))
            attention_mask = attention_mask.masked_fill(attention_mask == 0, float(0.0))
        else:
            attention_mask: Optional[torch.Tensor] = None
        # Save mask
        self.register_buffer("attention_mask", attention_mask)

    def update_resolution(self,
                          new_window_size: int,
                          new_input_resolution: Tuple[int, int]) -> None:
        """ Method updates the image resolution to be processed and window size and so the pair-wise relative positions.

        Args:
            new_window_size (int): New window size
            new_input_resolution (Tuple[int, int]): New input resolution
        """
        # Update input resolution
        self.input_resolution: Tuple[int, int] = new_input_resolution
        # Catch case if resolution is smaller than the window size
        if min(self.input_resolution) <= new_window_size:
            self.window_size: int = min(self.input_resolution)
            self.shift_size: int = 0
            self.make_windows: bool = False
        else:
            self.window_size: int = new_window_size
            # Restore the configured shift; an earlier, smaller resolution may have set it to zero.
            # getattr fallback keeps instances pickled before this attribute existed working.
            self.shift_size: int = getattr(self, "_configured_shift_size", self.shift_size)
            self.make_windows: bool = True
        # Update attention mask
        self.__make_attention_mask()
        # Update attention module with the *effective* window size, so its relative position bias matches
        # the number of tokens it actually receives (matches what __init__ does)
        self.window_attention.update_resolution(new_window_size=self.window_size)

    def forward(self,
                input: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        Args:
            input (torch.Tensor): Input tensor of the shape [B, C, H, W]
        Returns:
            output (torch.Tensor): Output tensor of the shape [B, C, H, W]
        """
        # Save shape
        batch_size, channels, height, width = input.shape  # type: int, int, int, int
        # Cyclically shift the input if utilized
        if self.shift_size > 0:
            output_shift: torch.Tensor = torch.roll(input=input, shifts=(-self.shift_size, -self.shift_size),
                                                    dims=(-1, -2))
        else:
            output_shift: torch.Tensor = input
        # Make patches
        output_patches: torch.Tensor = unfold(input=output_shift, window_size=self.window_size) \
            if self.make_windows else output_shift
        # Perform window attention
        output_attention: torch.Tensor = self.window_attention(output_patches, mask=self.attention_mask)
        # Merge patches
        output_merge: torch.Tensor = fold(input=output_attention, window_size=self.window_size, height=height,
                                          width=width) if self.make_windows else output_attention
        # Reverse shift if utilized
        if self.shift_size > 0:
            output_shift: torch.Tensor = torch.roll(input=output_merge, shifts=(self.shift_size, self.shift_size),
                                                    dims=(-1, -2))
        else:
            output_shift: torch.Tensor = output_merge
        # Post-normalization of the attention branch
        output_normalize: torch.Tensor = self.normalization_1(output_shift.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        # Skip connection
        output_skip: torch.Tensor = self.dropout(output_normalize) + input
        # Feed forward network, normalization and skip connection.
        # reshape (not view): output_skip derives from a permuted tensor and may be non-contiguous.
        output_feed_forward: torch.Tensor = self.feed_forward_network(
            output_skip.reshape(batch_size, channels, -1).permute(0, 2, 1)).permute(0, 2, 1)
        output_feed_forward = output_feed_forward.reshape(batch_size, channels, height, width)
        output_normalize = bhwc_to_bchw(self.normalization_2(bchw_to_bhwc(output_feed_forward)))
        output: torch.Tensor = output_skip + self.dropout(output_normalize)
        return output
class DeformableSwinTransformerBlock(SwinTransformerBlock):
    r""" This class implements a deformable version of the Swin Transformer block.

    Before the regular Swin block, the input is resampled per head with bilinear ``grid_sample`` using a
    learned offset field added to an identity sampling grid.
    Inspired by: https://arxiv.org/pdf/2201.00520

    Args:
        in_channels (int): Number of input channels
        input_resolution (Tuple[int, int]): Input resolution (height, width)
        number_of_heads (int): Number of attention heads to be utilized
        window_size (int): Window size to be utilized
        shift_size (int): Shifting size to be used
        ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
        dropout (float): Dropout in input mapping
        dropout_attention (float): Dropout rate of attention map
        dropout_path (float): Dropout in main path
        sequential_self_attention (bool): If true sequential self-attention is performed
        offset_downscale_factor (int): Stride of the offset network's depth-wise convolution
    """

    def __init__(self,
                 in_channels: int,
                 input_resolution: Tuple[int, int],
                 number_of_heads: int,
                 window_size: int = 7,
                 shift_size: int = 0,
                 ff_feature_ratio: int = 4,
                 dropout: float = 0.0,
                 dropout_attention: float = 0.0,
                 dropout_path: float = 0.0,
                 sequential_self_attention: bool = False,
                 offset_downscale_factor: int = 2) -> None:
        # Call super constructor
        super(DeformableSwinTransformerBlock, self).__init__(
            in_channels=in_channels,
            input_resolution=input_resolution,
            number_of_heads=number_of_heads,
            window_size=window_size,
            shift_size=shift_size,
            ff_feature_ratio=ff_feature_ratio,
            dropout=dropout,
            dropout_attention=dropout_attention,
            dropout_path=dropout_path,
            sequential_self_attention=sequential_self_attention
        )
        # Save parameter
        self.offset_downscale_factor: int = offset_downscale_factor
        self.number_of_heads: int = number_of_heads
        # Make default offsets
        self.__make_default_offsets()
        # Init offset network: strided depth-wise conv -> GELU -> 1x1 conv to (x, y) offsets per head
        self.offset_network: nn.Module = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=5, stride=offset_downscale_factor,
                      padding=3, groups=in_channels, bias=True),
            nn.GELU(),
            nn.Conv2d(in_channels=in_channels, out_channels=2 * self.number_of_heads, kernel_size=1, stride=1,
                      padding=0, bias=True)
        )

    def __make_default_offsets(self) -> None:
        """ Method generates the default (identity) sampling grid ``default_grid`` of shape [1, H, W, 2]
        holding normalized (x, y) coordinates in [-1, 1] (inspired by kornia). """
        height, width = self.input_resolution
        device = self.window_attention.tau.device
        # Init x and y pixel coordinates
        x: torch.Tensor = torch.linspace(0, width - 1, width, device=device)
        y: torch.Tensor = torch.linspace(0, height - 1, height, device=device)
        # Normalize coordinates to a range of [-1, 1]. The denominator is guarded so that a size-1 axis
        # yields -1 instead of NaN (0 / 0); with align_corners=True (as used in forward) every normalized
        # coordinate of a size-1 axis addresses its single pixel.
        x = (x / max(width - 1, 1) - 0.5) * 2
        y = (y / max(height - 1, 1) - 0.5) * 2
        # Make grid [2, height, width]
        grid: torch.Tensor = torch.stack(torch.meshgrid([x, y])).transpose(1, 2)
        # Reshape grid to [1, height, width, 2]
        grid = grid.unsqueeze(dim=0).permute(0, 2, 3, 1)
        # Register in module
        self.register_buffer("default_grid", grid)

    def update_resolution(self,
                          new_window_size: int,
                          new_input_resolution: Tuple[int, int]) -> None:
        """ Method updates the window size and so the pair-wise relative positions.

        Args:
            new_window_size (int): New window size
            new_input_resolution (Tuple[int, int]): New input resolution
        """
        # Update resolution and window size
        super(DeformableSwinTransformerBlock, self).update_resolution(new_window_size=new_window_size,
                                                                      new_input_resolution=new_input_resolution)
        # Update default sampling grid
        self.__make_default_offsets()

    def forward(self,
                input: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        Args:
            input (torch.Tensor): Input tensor of the shape [B, C, H, W]
        Returns:
            output (torch.Tensor): Output tensor of the shape [B, C, H, W]
        """
        # Get input shape
        batch_size, channels, height, width = input.shape
        # Compute offsets on a spatially downscaled map: [batch size, 2 * number of heads, h', w']
        offsets: torch.Tensor = self.offset_network(input)
        # Upscale offsets to the shape [batch size, 2 * number of heads, height, width]
        offsets = F.interpolate(input=offsets, size=(height, width), mode="bilinear", align_corners=True)
        # Reshape offsets to [batch size, number of heads, height, width, 2]
        offsets = offsets.reshape(batch_size, -1, 2, height, width).permute(0, 1, 3, 4, 2)
        # Flatten batch size and number of heads and squash to [-1, 1] with tanh
        offsets = offsets.view(-1, height, width, 2).tanh()
        # Cast a local copy of the default grid to the input dtype. Overwriting the buffer (as before)
        # permanently lost precision once a low-precision forward pass had run.
        default_grid: torch.Tensor = self.default_grid
        if input.dtype != default_grid.dtype:
            default_grid = default_grid.type(input.dtype)
        # Construct offset grid [batch size * number of heads, height, width, 2]
        offset_grid: torch.Tensor = default_grid.repeat_interleave(repeats=offsets.shape[0], dim=0) + offsets
        # Reshape input to [batch size * number of heads, channels / number of heads, height, width]
        input = input.view(batch_size, self.number_of_heads, channels // self.number_of_heads, height,
                           width).flatten(start_dim=0, end_dim=1)
        # Apply sampling grid
        input_resampled: torch.Tensor = F.grid_sample(input=input, grid=offset_grid.clip(min=-1, max=1),
                                                      mode="bilinear", align_corners=True,
                                                      padding_mode="reflection")
        # Reshape resampled tensor again to [batch size, channels, height, width]
        input_resampled = input_resampled.view(batch_size, channels, height, width)
        output: torch.Tensor = super(DeformableSwinTransformerBlock, self).forward(input=input_resampled)
        return output
class PatchMerging(nn.Module):
    """ Patch merging: every 2x2 neighbourhood of tokens is concatenated along the channel axis (C -> 4C),
    layer-normalized, and projected to 2C by a bias-free linear layer. The spatial size is halved.

    Args:
        in_channels (int): Number of input channels
    """

    def __init__(self,
                 in_channels: int) -> None:
        super(PatchMerging, self).__init__()
        merged_channels: int = 4 * in_channels
        # Normalization over the concatenated 2x2 patch, then linear projection 4C -> 2C
        self.normalization: nn.Module = nn.LayerNorm(normalized_shape=merged_channels)
        self.linear_mapping: nn.Module = nn.Linear(in_features=merged_channels, out_features=2 * in_channels,
                                                   bias=False)

    def forward(self,
                input: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        If H or W is odd, ``Tensor.unfold`` drops the last row/column.

        Args:
            input (torch.Tensor): Input tensor of the shape [B, C, H, W]
        Returns:
            output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
        """
        batch_size: int = input.shape[0]
        # (B, C, H, W) -> (B, H, W, C)
        channels_last: torch.Tensor = bchw_to_bhwc(input)
        # Gather 2x2 patches: (B, H // 2, W // 2, C, 2, 2)
        patches: torch.Tensor = channels_last.unfold(dimension=1, size=2, step=2).unfold(dimension=2, size=2, step=2)
        # Flatten each patch (channel-major) into 4C features: (B, H // 2, W // 2, 4C)
        merged: torch.Tensor = patches.reshape(batch_size, patches.shape[1], patches.shape[2], -1)
        projected: torch.Tensor = self.linear_mapping(self.normalization(merged))
        return bhwc_to_bchw(projected)
class PatchEmbedding(nn.Module):
    """ Embeds an image into non-overlapping patch tokens with a strided convolution followed by LayerNorm
    over the channel axis.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output (embedding) channels
        patch_size (int): Patch size, used as both convolution kernel size and stride
    """

    def __init__(self,
                 in_channels: int = 3,
                 out_channels: int = 96,
                 patch_size: int = 4) -> None:
        super(PatchEmbedding, self).__init__()
        self.out_channels: int = out_channels
        patch: Tuple[int, int] = (patch_size, patch_size)
        # Linear embedding of each patch, implemented as a convolution with kernel == stride
        self.linear_embedding: nn.Module = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                                     kernel_size=patch, stride=patch)
        self.normalization: nn.Module = nn.LayerNorm(normalized_shape=out_channels)

    def forward(self,
                input: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        Args:
            input (torch.Tensor): Input image of the shape (B, C_in, H, W)
        Returns:
            embedding (torch.Tensor): Embedding of the shape (B, C_out, H / patch size, W / patch size)
        """
        projected: torch.Tensor = self.linear_embedding(input)
        # LayerNorm acts on the last axis, so normalize in channels-last layout and permute back
        normalized: torch.Tensor = self.normalization(bchw_to_bhwc(projected))
        return bhwc_to_bchw(normalized)
class SwinTransformerStage(nn.Module):
    r""" This class implements a stage of the Swin transformer including multiple blocks.

    Blocks alternate between regular windows (even index) and windows shifted by ``window_size // 2``
    (odd index).

    Args:
        in_channels (int): Number of input channels
        depth (int): Depth of the stage (number of blocks)
        downscale (bool): If true, the input is first downsampled by patch merging (see Fig. 3 of V1 paper)
        input_resolution (Tuple[int, int]): Input resolution before the optional downsampling
        number_of_heads (int): Number of attention heads to be utilized
        window_size (int): Window size to be utilized
        ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
        dropout (float): Dropout in input mapping
        dropout_attention (float): Dropout rate of attention map
        dropout_path (Union[List[float], Tuple[float, ...], float]): Drop path rate, either a single value for
            all blocks or one value per block (list or tuple of length >= depth)
        use_checkpoint (bool): If true checkpointing is utilized
        sequential_self_attention (bool): If true sequential self-attention is performed
        use_deformable_block (bool): If true deformable block is used
    """

    def __init__(self,
                 in_channels: int,
                 depth: int,
                 downscale: bool,
                 input_resolution: Tuple[int, int],
                 number_of_heads: int,
                 window_size: int = 7,
                 ff_feature_ratio: int = 4,
                 dropout: float = 0.0,
                 dropout_attention: float = 0.0,
                 dropout_path: Union[List[float], Tuple[float, ...], float] = 0.0,
                 use_checkpoint: bool = False,
                 sequential_self_attention: bool = False,
                 use_deformable_block: bool = False) -> None:
        # Call super constructor
        super(SwinTransformerStage, self).__init__()
        # Save parameters
        self.use_checkpoint: bool = use_checkpoint
        self.downscale: bool = downscale
        # Init downsampling
        self.downsample: nn.Module = PatchMerging(in_channels=in_channels) if downscale else nn.Identity()
        # Update resolution and channels
        self.input_resolution: Tuple[int, int] = (input_resolution[0] // 2, input_resolution[1] // 2) \
            if downscale else input_resolution
        in_channels = in_channels * 2 if downscale else in_channels
        # Get block
        block = DeformableSwinTransformerBlock if use_deformable_block else SwinTransformerBlock
        # Per-block drop path rates may be given as a list or a tuple; a scalar applies to every block
        per_block_dropout_path: bool = isinstance(dropout_path, (list, tuple))
        # Init blocks
        self.blocks: nn.ModuleList = nn.ModuleList([
            block(in_channels=in_channels,
                  input_resolution=self.input_resolution,
                  number_of_heads=number_of_heads,
                  window_size=window_size,
                  shift_size=0 if ((index % 2) == 0) else window_size // 2,
                  ff_feature_ratio=ff_feature_ratio,
                  dropout=dropout,
                  dropout_attention=dropout_attention,
                  dropout_path=dropout_path[index] if per_block_dropout_path else dropout_path,
                  sequential_self_attention=sequential_self_attention)
            for index in range(depth)])

    def update_resolution(self,
                          new_window_size: int,
                          new_input_resolution: Tuple[int, int]) -> None:
        """ Method updates the resolution to utilize and the window size and so the pair-wise relative positions.

        Args:
            new_window_size (int): New window size
            new_input_resolution (Tuple[int, int]): New input resolution (before the optional downsampling)
        """
        # Update resolution
        self.input_resolution: Tuple[int, int] = (new_input_resolution[0] // 2, new_input_resolution[1] // 2) \
            if self.downscale else new_input_resolution
        # Update resolution of each block
        for block in self.blocks:  # type: SwinTransformerBlock
            block.update_resolution(new_window_size=new_window_size, new_input_resolution=self.input_resolution)

    def forward(self,
                input: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        Args:
            input (torch.Tensor): Input tensor of the shape [B, C, H, W]
        Returns:
            output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2] if ``downscale``,
                otherwise [B, C, H, W]
        """
        # Downscale input tensor (identity if downscale is False)
        output: torch.Tensor = self.downsample(input)
        # Forward pass of each block
        for block in self.blocks:  # type: nn.Module
            # Perform checkpointing if utilized
            if self.use_checkpoint:
                output = checkpoint.checkpoint(block, output)
            else:
                output = block(output)
        return output
class SwinTransformerV2(nn.Module):
    r""" Swin Transformer V2
    A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` -
        https://arxiv.org/pdf/2111.09883
    Args:
        in_channels (int): Number of input image channels
        embedding_channels (int): Number of channels produced by the patch embedding (input channels of stage 0)
        depths (Tuple[int, ...]): Number of blocks in each stage
        input_resolution (Tuple[int, int]): Input image resolution (H, W)
        number_of_heads (Tuple[int, ...]): Number of attention heads of each stage
        num_classes (int): Number of output classes
        window_size (int): Window size to be utilized
        patch_size (int): Patch size of the patch embedding
        ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
        dropout (float): Dropout rate forwarded to every stage
        dropout_attention (float): Dropout rate of attention map
        dropout_path (float): Maximum drop path rate; per-block rates grow linearly from 0 to this value
        use_checkpoint (bool): If true checkpointing is utilized
        sequential_self_attention (bool): If true sequential self-attention is performed
        use_deformable_block (bool): If true deformable blocks are used in every stage except the first
    """

    def __init__(self,
                 in_channels: int,
                 embedding_channels: int,
                 depths: Tuple[int, ...],
                 input_resolution: Tuple[int, int],
                 number_of_heads: Tuple[int, ...],
                 num_classes: int = 1000,
                 window_size: int = 7,
                 patch_size: int = 4,
                 ff_feature_ratio: int = 4,
                 dropout: float = 0.0,
                 dropout_attention: float = 0.0,
                 dropout_path: float = 0.2,
                 use_checkpoint: bool = False,
                 sequential_self_attention: bool = False,
                 use_deformable_block: bool = False) -> None:
        # Call super constructor
        super(SwinTransformerV2, self).__init__()
        # Save parameters
        self.patch_size: int = patch_size
        self.input_resolution: Tuple[int, int] = input_resolution
        self.window_size: int = window_size
        # Init patch embedding
        self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_channels, out_channels=embedding_channels,
                                                         patch_size=patch_size)
        # Compute patch resolution
        patch_resolution: Tuple[int, int] = (input_resolution[0] // patch_size, input_resolution[1] // patch_size)
        # Stochastic depth: drop path rate grows linearly from 0 to dropout_path over all blocks of all stages
        drop_path_rates: List[float] = torch.linspace(0., dropout_path, sum(depths)).tolist()
        # Init stages
        self.stages: nn.ModuleList = nn.ModuleList()
        for index, (depth, number_of_head) in enumerate(zip(depths, number_of_heads)):
            self.stages.append(
                SwinTransformerStage(
                    in_channels=embedding_channels * (2 ** max(index - 1, 0)),
                    depth=depth,
                    downscale=index != 0,
                    input_resolution=self._stage_input_resolution(patch_resolution, index),
                    number_of_heads=number_of_head,
                    window_size=window_size,
                    ff_feature_ratio=ff_feature_ratio,
                    dropout=dropout,
                    dropout_attention=dropout_attention,
                    dropout_path=drop_path_rates[sum(depths[:index]):sum(depths[:index + 1])],
                    use_checkpoint=use_checkpoint,
                    sequential_self_attention=sequential_self_attention,
                    use_deformable_block=use_deformable_block and (index > 0)
                ))
        # Every stage except the first downscales and doubles its input channels, so the last stage emits
        # embedding_channels * 2 ** (num_stages - 1) channels. The parentheses matter: `2 ** n - 1` would
        # evaluate as `(2 ** n) - 1` (e.g. 15 instead of 8 for four stages) and mismatch the head input.
        self.num_features: int = embedding_channels * (2 ** max(len(self.stages) - 1, 0))
        # Init final adaptive average pooling, and classification head
        self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1)
        self.head: nn.Module = nn.Linear(in_features=self.num_features, out_features=num_classes)

    @staticmethod
    def _stage_input_resolution(patch_resolution: Tuple[int, int], index: int) -> Tuple[int, int]:
        """ Resolution handed to stage `index`, before that stage's own downscaling.
        Stages 0 and 1 both receive the patch resolution; every later stage receives half the resolution of
        the previous one (because the previous stage downscaled).
        Args:
            patch_resolution (Tuple[int, int]): Resolution after patch embedding
            index (int): Stage index
        Returns:
            resolution (Tuple[int, int]): Input resolution of the stage
        """
        factor: int = 2 ** max(index - 1, 0)
        return patch_resolution[0] // factor, patch_resolution[1] // factor

    def update_resolution(self,
                          new_input_resolution: Optional[Tuple[int, int]] = None,
                          new_window_size: Optional[int] = None) -> None:
        """ Method updates the image resolution to be processed and window size and so the pair-wise relative positions.
        Args:
            new_input_resolution (Optional[Tuple[int, int]]): New input resolution if None current resolution is used
            new_window_size (Optional[int]): New window size if None current window size is used
        """
        # Check parameters
        if new_input_resolution is None:
            new_input_resolution = self.input_resolution
        if new_window_size is None:
            new_window_size = self.window_size
        # Persist the new values so a later partial update (e.g. only a new window size) does not silently
        # revert the other setting to its construction-time value
        self.input_resolution = new_input_resolution
        self.window_size = new_window_size
        # Compute new patch resolution
        new_patch_resolution: Tuple[int, int] = (new_input_resolution[0] // self.patch_size,
                                                 new_input_resolution[1] // self.patch_size)
        # Update resolution of each stage (same per-stage schedule as in __init__)
        for index, stage in enumerate(self.stages):  # type: int, SwinTransformerStage
            stage.update_resolution(new_window_size=new_window_size,
                                    new_input_resolution=self._stage_input_resolution(new_patch_resolution, index))

    def forward_features(self,
                         input: torch.Tensor) -> List[torch.Tensor]:
        """ Forward pass to extract feature maps of each stage.
        Args:
            input (torch.Tensor): Input images of the shape (B, C, H, W)
        Returns:
            features (List[torch.Tensor]): List of feature maps from each stage
        """
        # Perform patch embedding
        output: torch.Tensor = self.patch_embedding(input)
        # Init list to store feature
        features: List[torch.Tensor] = []
        # Forward pass of each stage
        for stage in self.stages:
            output = stage(output)
            features.append(output)
        return features

    def forward(self,
                input: torch.Tensor) -> torch.Tensor:
        """ Forward pass.
        Args:
            input (torch.Tensor): Input images of the shape (B, C, H, W)
        Returns:
            classification (torch.Tensor): Classification of the shape (B, num_classes)
        """
        # Perform patch embedding
        output: torch.Tensor = self.patch_embedding(input)
        # Forward pass of each stage
        for stage in self.stages:
            output = stage(output)
        # Average pooling gives (B, C, 1, 1); nn.Linear acts on the last dimension, so flatten to (B, C) first
        output = torch.flatten(self.average_pool(output), start_dim=1)
        # Predict classification
        classification: torch.Tensor = self.head(output)
        return classification