You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
928 lines
46 KiB
928 lines
46 KiB
3 years ago
|
""" Swin Transformer V2
|
||
|
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
|
||
|
- https://arxiv.org/pdf/2111.09883
|
||
|
|
||
|
Code adapted from https://github.com/ChristophReich1996/Swin-Transformer-V2, original copyright/license info below
|
||
|
|
||
|
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
|
||
|
"""
|
||
|
# --------------------------------------------------------
|
||
|
# Swin Transformer V2 reimplementation
|
||
|
# Copyright (c) 2021 Christoph Reich
|
||
|
# Licensed under The MIT License [see LICENSE for details]
|
||
|
# Written by Christoph Reich
|
||
|
# --------------------------------------------------------
|
||
|
from typing import Tuple, Optional, List, Union, Any
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
import torch.utils.checkpoint as checkpoint
|
||
|
|
||
|
from .layers import DropPath, Mlp
|
||
|
|
||
|
|
||
|
def bchw_to_bhwc(input: torch.Tensor) -> torch.Tensor:
    """ Moves the channel axis of a 4D tensor to the last position.

    Args:
        input (torch.Tensor): Tensor laid out as (B, C, H, W)

    Returns:
        torch.Tensor: Permuted view of the input laid out as (B, H, W, C)
    """
    # permute only re-labels the axes; no data is copied.
    return input.permute(0, 2, 3, 1)
|
||
|
|
||
|
|
||
|
def bhwc_to_bchw(input: torch.Tensor) -> torch.Tensor:
    """ Moves the trailing channel axis of a 4D tensor to position one.

    Args:
        input (torch.Tensor): Tensor laid out as (B, H, W, C)

    Returns:
        torch.Tensor: Permuted view of the input laid out as (B, C, H, W)
    """
    # permute only re-labels the axes; no data is copied.
    return input.permute(0, 3, 1, 2)
|
||
|
|
||
|
|
||
|
def unfold(input: torch.Tensor,
           window_size: int) -> torch.Tensor:
    """ Cuts a feature map into non-overlapping square windows (stride = window size).

    Args:
        input (torch.Tensor): Input feature map of the shape (B, C, H, W)
        window_size (int): Edge length of each window

    Returns:
        torch.Tensor: Windows of the shape (B * windows, C, window size, window size), ordered row-major
            over the window grid of each batch element
    """
    _, channels, _, _ = input.shape
    # Tile the width axis first and the height axis second:
    # (B, C, H, W) -> (B, C, H // ws, W // ws, ws [width], ws [height])
    tiles = input.unfold(dimension=3, size=window_size, step=window_size)
    tiles = tiles.unfold(dimension=2, size=window_size, step=window_size)
    # Bring the window grid in front of the channels and restore (height, width) order inside each window.
    tiles = tiles.permute(0, 2, 3, 1, 5, 4)
    return tiles.reshape(-1, channels, window_size, window_size)
|
||
|
|
||
|
|
||
|
def fold(input: torch.Tensor,
         window_size: int,
         height: int,
         width: int) -> torch.Tensor:
    """ Reassembles non-overlapping windows into a 4D feature map.

    Args:
        input (torch.Tensor): Windows of the shape (B * windows, C, window size, window size), ordered
            row-major over the window grid of each batch element
        window_size (int): Window size of the unfold operation
        height (int): Height of the feature map
        width (int): Width of the feature map

    Returns:
        torch.Tensor: Folded output tensor of the shape (B, C, H, W)
    """
    channels = input.shape[1]
    # The batch size is recovered from the number of windows each feature map contributes.
    windows_per_map = height * width // window_size // window_size
    batch_size = int(input.shape[0] // windows_per_map)
    # (B * windows, C, ws, ws) -> (B, H // ws, W // ws, C, ws, ws)
    grid = input.view(batch_size, height // window_size, width // window_size, channels,
                      window_size, window_size)
    # Interleave grid and in-window axes: (B, C, H // ws, ws, W // ws, ws), then merge to (B, C, H, W).
    grid = grid.permute(0, 3, 1, 4, 2, 5)
    return grid.reshape(batch_size, channels, height, width)
|
||
|
|
||
|
|
||
|
class WindowMultiHeadAttention(nn.Module):
    r""" This class implements window-based Multi-Head-Attention with log-spaced continuous position bias.

    Attention logits are the cosine similarity between queries and keys, divided by a learnable per-head
    temperature ``tau`` (clamped to at least 0.01). A two layer MLP (the meta network) maps the log-spaced
    relative coordinates of every token pair to a per-head bias that is added to the logits.

    Args:
        in_features (int): Number of input features
        window_size (int): Window size
        number_of_heads (int): Number of attention heads
        dropout_attention (float): Dropout rate of attention map
        dropout_projection (float): Dropout rate after projection
        meta_network_hidden_features (int): Number of hidden features in the two layer MLP meta network
        sequential_self_attention (bool): If true sequential self-attention is performed
    """

    def __init__(self,
                 in_features: int,
                 window_size: int,
                 number_of_heads: int,
                 dropout_attention: float = 0.,
                 dropout_projection: float = 0.,
                 meta_network_hidden_features: int = 256,
                 sequential_self_attention: bool = False) -> None:
        # Call super constructor
        super(WindowMultiHeadAttention, self).__init__()
        # Check parameter
        assert (in_features % number_of_heads) == 0, \
            "The number of input features (in_features) are not divisible by the number of heads (number_of_heads)."
        # Save parameters
        self.in_features: int = in_features
        self.window_size: int = window_size
        self.number_of_heads: int = number_of_heads
        self.sequential_self_attention: bool = sequential_self_attention
        # Init query, key and value mapping as a single layer
        self.mapping_qkv: nn.Module = nn.Linear(in_features=in_features, out_features=in_features * 3, bias=True)
        # Init attention dropout
        self.attention_dropout: nn.Module = nn.Dropout(dropout_attention)
        # Init projection mapping
        self.projection: nn.Module = nn.Linear(in_features=in_features, out_features=in_features, bias=True)
        # Init projection dropout
        self.projection_dropout: nn.Module = nn.Dropout(dropout_projection)
        # Init meta network for positional encodings
        self.meta_network: nn.Module = nn.Sequential(
            nn.Linear(in_features=2, out_features=meta_network_hidden_features, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=meta_network_hidden_features, out_features=number_of_heads, bias=True))
        # Init tau (learnable per-head temperature of the cosine attention)
        self.register_parameter("tau", torch.nn.Parameter(torch.ones(1, number_of_heads, 1, 1)))
        # Init pair-wise relative positions (log-spaced)
        self.__make_pair_wise_relative_positions()

    def __make_pair_wise_relative_positions(self) -> None:
        """ Method initializes the pair-wise relative positions to compute the positional biases.

        Registers the buffer ``relative_coordinates_log`` of the shape (window size ** 4, 2). Row
        ``i * window size ** 2 + j`` holds ``sign(d) * log(1 + |d|)`` where ``d`` is the (row, column)
        offset of token ``i`` relative to token ``j`` in the row-major flattened window.
        """
        indexes: torch.Tensor = torch.arange(self.window_size, device=self.tau.device)
        # Row and column index of every token of the flattened window, shape (2, window size ** 2).
        # Built arithmetically so the result does not depend on the indexing default of torch.meshgrid.
        coordinates: torch.Tensor = torch.stack((indexes.repeat_interleave(self.window_size),
                                                 indexes.repeat(self.window_size)), dim=0)
        # Offset between every pair of tokens, shape (2, window size ** 2, window size ** 2)
        relative_coordinates: torch.Tensor = coordinates[:, :, None] - coordinates[:, None, :]
        relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float()
        relative_coordinates_log: torch.Tensor = torch.sign(relative_coordinates) \
            * torch.log(1. + relative_coordinates.abs())
        self.register_buffer("relative_coordinates_log", relative_coordinates_log)

    def update_resolution(self,
                          new_window_size: int,
                          **kwargs: Any) -> None:
        """ Method updates the window size and so the pair-wise relative positions

        Args:
            new_window_size (int): New window size
            kwargs (Any): Unused
        """
        # Set new window size
        self.window_size: int = new_window_size
        # Make new pair-wise relative positions
        self.__make_pair_wise_relative_positions()

    def __get_relative_positional_encodings(self) -> torch.Tensor:
        """ Method computes the relative positional encodings

        Returns:
            relative_position_bias (torch.Tensor): Relative positional encodings
                (1, number of heads, window size ** 2, window size ** 2)
        """
        # (window size ** 4, 2) -> (window size ** 4, heads) -> (heads, window size ** 4)
        relative_position_bias: torch.Tensor = self.meta_network(self.relative_coordinates_log).permute(1, 0)
        tokens: int = self.window_size * self.window_size
        relative_position_bias = relative_position_bias.reshape(self.number_of_heads, tokens, tokens)
        return relative_position_bias.unsqueeze(0)

    def __self_attention(self,
                         query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         batch_size_windows: int,
                         tokens: int,
                         mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """ This function performs standard (non-sequential) scaled cosine self-attention.

        Args:
            query (torch.Tensor): Query tensor of the shape [B * windows, heads, tokens, C // heads]
            key (torch.Tensor): Key tensor of the shape [B * windows, heads, tokens, C // heads]
            value (torch.Tensor): Value tensor of the shape (B * windows, heads, tokens, C // heads)
            batch_size_windows (int): Size of the first dimension of the input tensor batch size * windows
            tokens (int): Number of tokens in the input
            mask (Optional[torch.Tensor]): Attention mask for the shift case

        Returns:
            output (torch.Tensor): Output feature map of the shape [B * windows, tokens, C]
        """
        # Compute attention map with scaled cosine attention. The product of the norms is clamped
        # from below so that all-zero queries or keys do not cause a division by zero.
        norm_product: torch.Tensor = torch.norm(query, dim=-1, keepdim=True) \
            * torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1)
        attention_map: torch.Tensor = torch.einsum("bhqd, bhkd -> bhqk", query, key) \
            / norm_product.clamp(min=1e-06)
        attention_map = attention_map / self.tau.clamp(min=0.01)
        # Apply relative positional encodings
        attention_map = attention_map + self.__get_relative_positional_encodings()
        # Apply mask if utilized
        if mask is not None:
            number_of_windows: int = mask.shape[0]
            # Expose the window axis so the (windows, tokens, tokens) mask broadcasts over batch and heads.
            attention_map = attention_map.view(batch_size_windows // number_of_windows, number_of_windows,
                                               self.number_of_heads, tokens, tokens)
            attention_map = attention_map + mask.unsqueeze(1).unsqueeze(0)
            attention_map = attention_map.view(-1, self.number_of_heads, tokens, tokens)
        attention_map = attention_map.softmax(dim=-1)
        # Perform attention dropout
        attention_map = self.attention_dropout(attention_map)
        # Apply attention map and reshape
        output: torch.Tensor = torch.einsum("bhal, bhlv -> bhav", attention_map, value)
        return output.permute(0, 2, 1, 3).reshape(batch_size_windows, tokens, -1)

    def __sequential_self_attention(self,
                                    query: torch.Tensor,
                                    key: torch.Tensor,
                                    value: torch.Tensor,
                                    batch_size_windows: int,
                                    tokens: int,
                                    mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """ This function performs sequential scaled cosine self-attention.

        One query token is processed per iteration, so only a (tokens,) slice of the attention map per
        window and head exists at a time instead of the full (tokens, tokens) map.

        Args:
            query (torch.Tensor): Query tensor of the shape [B * windows, heads, tokens, C // heads]
            key (torch.Tensor): Key tensor of the shape [B * windows, heads, tokens, C // heads]
            value (torch.Tensor): Value tensor of the shape (B * windows, heads, tokens, C // heads)
            batch_size_windows (int): Size of the first dimension of the input tensor batch size * windows
            tokens (int): Number of tokens in the input
            mask (Optional[torch.Tensor]): Attention mask for the shift case

        Returns:
            output (torch.Tensor): Output feature map of the shape [B * windows, tokens, C]
        """
        # Init output tensor; every query slot is overwritten in the loop below
        output: torch.Tensor = torch.empty_like(query)
        # Everything that does not depend on the query token is computed once up front.
        relative_position_bias: torch.Tensor = self.__get_relative_positional_encodings()
        key_norm: torch.Tensor = torch.norm(key, dim=-1, keepdim=False)  # (B * windows, heads, tokens)
        tau: torch.Tensor = self.tau.clamp(min=0.01)[..., 0]  # (1, heads, 1)
        number_of_windows: int = 0
        expanded_mask: Optional[torch.Tensor] = None
        if mask is not None:
            number_of_windows = mask.shape[0]
            expanded_mask = mask.unsqueeze(1).unsqueeze(0)  # (1, windows, 1, tokens, tokens)
        # Iterate over query tokens
        for token_index_query in range(tokens):
            query_token: torch.Tensor = query[:, :, token_index_query]
            # Compute attention map with scaled cosine attention
            norm_product: torch.Tensor = torch.norm(query_token, dim=-1, keepdim=True) * key_norm
            attention_map: torch.Tensor = torch.einsum("bhd, bhkd -> bhk", query_token, key) \
                / norm_product.clamp(min=1e-06)
            attention_map = attention_map / tau
            # Apply positional encodings
            attention_map = attention_map + relative_position_bias[..., token_index_query, :]
            # Apply mask if utilized
            if expanded_mask is not None:
                attention_map = attention_map.view(batch_size_windows // number_of_windows,
                                                   number_of_windows, self.number_of_heads, 1, tokens)
                attention_map = attention_map + expanded_mask[..., token_index_query, :].unsqueeze(3)
                attention_map = attention_map.view(-1, self.number_of_heads, tokens)
            attention_map = attention_map.softmax(dim=-1)
            # Perform attention dropout
            attention_map = self.attention_dropout(attention_map)
            # Apply attention map
            output[:, :, token_index_query] = torch.einsum("bhl, bhlv -> bhv", attention_map, value)
        return output.permute(0, 2, 1, 3).reshape(batch_size_windows, tokens, -1)

    def forward(self,
                input: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """ Forward pass.

        Args:
            input (torch.Tensor): Input tensor of the shape (B * windows, C, H, W)
            mask (Optional[torch.Tensor]): Attention mask for the shift case

        Returns:
            output (torch.Tensor): Output tensor of the shape [B * windows, C, H, W]
        """
        # Save original shape
        batch_size_windows, channels, height, width = input.shape  # type: int, int, int, int
        tokens: int = height * width
        # Reshape input to (B * windows, tokens (height * width), C)
        input = input.reshape(batch_size_windows, channels, tokens).permute(0, 2, 1)
        # Perform query, key, and value mapping -> (3, B * windows, heads, tokens, C // heads)
        query_key_value: torch.Tensor = self.mapping_qkv(input)
        query_key_value = query_key_value.view(batch_size_windows, tokens, 3, self.number_of_heads,
                                               channels // self.number_of_heads).permute(2, 0, 3, 1, 4)
        query, key, value = query_key_value[0], query_key_value[1], query_key_value[2]
        # Perform attention
        if self.sequential_self_attention:
            output: torch.Tensor = self.__sequential_self_attention(query=query, key=key, value=value,
                                                                    batch_size_windows=batch_size_windows,
                                                                    tokens=tokens,
                                                                    mask=mask)
        else:
            output: torch.Tensor = self.__self_attention(query=query, key=key, value=value,
                                                         batch_size_windows=batch_size_windows, tokens=tokens,
                                                         mask=mask)
        # Perform linear mapping and dropout
        output = self.projection_dropout(self.projection(output))
        # Reshape output to original shape [B * windows, C, H, W]
        return output.permute(0, 2, 1).view(batch_size_windows, channels, height, width)
|
||
|
|
||
|
|
||
|
class SwinTransformerBlock(nn.Module):
    r""" This class implements the Swin transformer block.

    The block applies (optionally shifted) window attention followed by a feed-forward network. Each
    sub-layer output is layer-normalized before it is added to the residual branch.

    Args:
        in_channels (int): Number of input channels
        input_resolution (Tuple[int, int]): Input resolution
        number_of_heads (int): Number of attention heads to be utilized
        window_size (int): Window size to be utilized
        shift_size (int): Shifting size to be used
        ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
        dropout (float): Dropout in input mapping
        dropout_attention (float): Dropout rate of attention map
        dropout_path (float): Dropout in main path
        sequential_self_attention (bool): If true sequential self-attention is performed
    """

    def __init__(self,
                 in_channels: int,
                 input_resolution: Tuple[int, int],
                 number_of_heads: int,
                 window_size: int = 7,
                 shift_size: int = 0,
                 ff_feature_ratio: int = 4,
                 dropout: float = 0.0,
                 dropout_attention: float = 0.0,
                 dropout_path: float = 0.0,
                 sequential_self_attention: bool = False) -> None:
        # Call super constructor
        super(SwinTransformerBlock, self).__init__()
        # Save parameters
        self.in_channels: int = in_channels
        self.input_resolution: Tuple[int, int] = input_resolution
        # Keep the requested shift: the effective shift is forced to zero whenever the window covers the
        # whole feature map, and must be restorable if the resolution later grows again.
        self.target_shift_size: int = shift_size
        # Derive effective window size, shift size and windowing flag
        self.__configure_windows(window_size)
        # Init normalization layers
        self.normalization_1: nn.Module = nn.LayerNorm(normalized_shape=in_channels)
        self.normalization_2: nn.Module = nn.LayerNorm(normalized_shape=in_channels)
        # Init window attention module
        self.window_attention: WindowMultiHeadAttention = WindowMultiHeadAttention(
            in_features=in_channels,
            window_size=self.window_size,
            number_of_heads=number_of_heads,
            dropout_attention=dropout_attention,
            dropout_projection=dropout,
            sequential_self_attention=sequential_self_attention)
        # Init dropout layer
        self.dropout: nn.Module = DropPath(drop_prob=dropout_path) if dropout_path > 0. else nn.Identity()
        # Init feed-forward network
        self.feed_forward_network: nn.Module = Mlp(in_features=in_channels,
                                                   hidden_features=int(in_channels * ff_feature_ratio),
                                                   drop=dropout,
                                                   out_features=in_channels)
        # Make attention mask
        self.__make_attention_mask()

    def __configure_windows(self, window_size: int) -> None:
        """ Sets window_size, shift_size and make_windows for the current input resolution.

        Args:
            window_size (int): Requested window size
        """
        # Catch case if resolution is smaller than (or equal to) the window size: a single window spans
        # the whole feature map, so neither windowing nor shifting is performed.
        if min(self.input_resolution) <= window_size:
            self.window_size: int = min(self.input_resolution)
            self.shift_size: int = 0
            self.make_windows: bool = False
        else:
            self.window_size: int = window_size
            self.shift_size: int = self.target_shift_size
            self.make_windows: bool = True

    def __make_attention_mask(self) -> None:
        """ Method generates the attention mask used in shift case.

        Registers the buffer ``attention_mask``: ``None`` without shift, otherwise a tensor of the shape
        (windows, window size ** 2, window size ** 2) holding 0 for token pairs that belong to the same
        image region after the cyclic shift and -100 for all other pairs.
        """
        # Make masks for shift case
        if self.shift_size > 0:
            height, width = self.input_resolution  # type: int, int
            mask: torch.Tensor = torch.zeros(height, width, device=self.window_attention.tau.device)
            # The cyclic shift splits each axis into three bands; every band combination gets its own id.
            height_slices: Tuple = (slice(0, -self.window_size),
                                    slice(-self.window_size, -self.shift_size),
                                    slice(-self.shift_size, None))
            width_slices: Tuple = (slice(0, -self.window_size),
                                   slice(-self.window_size, -self.shift_size),
                                   slice(-self.shift_size, None))
            counter: int = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    mask[height_slice, width_slice] = counter
                    counter += 1
            mask_windows: torch.Tensor = unfold(mask[None, None], self.window_size)
            mask_windows = mask_windows.reshape(-1, self.window_size * self.window_size)
            # Tokens may only attend to each other if their region ids match (difference of zero).
            attention_mask: Optional[torch.Tensor] = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attention_mask = attention_mask.masked_fill(attention_mask != 0, float(-100.0))
        else:
            attention_mask: Optional[torch.Tensor] = None
        # Save mask
        self.register_buffer("attention_mask", attention_mask)

    def update_resolution(self,
                          new_window_size: int,
                          new_input_resolution: Tuple[int, int]) -> None:
        """ Method updates the image resolution to be processed and window size and so the pair-wise relative positions.

        Args:
            new_window_size (int): New window size
            new_input_resolution (Tuple[int, int]): New input resolution
        """
        # Update input resolution
        self.input_resolution: Tuple[int, int] = new_input_resolution
        # Re-derive effective window size and shift (restores the requested shift if it was disabled)
        self.__configure_windows(new_window_size)
        # Update attention mask
        self.__make_attention_mask()
        # Update attention module with the effective window size. It can be smaller than new_window_size
        # when the feature map is smaller than the requested window, and the position bias table of the
        # attention module must match the number of tokens it actually receives.
        self.window_attention.update_resolution(new_window_size=self.window_size)

    def forward(self,
                input: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        Args:
            input (torch.Tensor): Input tensor of the shape [B, C, H, W]

        Returns:
            output (torch.Tensor): Output tensor of the shape [B, C, H, W]
        """
        # Save shape
        batch_size, channels, height, width = input.shape  # type: int, int, int, int
        # Shift input if utilized
        if self.shift_size > 0:
            output_shift: torch.Tensor = torch.roll(input=input, shifts=(-self.shift_size, -self.shift_size),
                                                    dims=(-1, -2))
        else:
            output_shift: torch.Tensor = input
        # Make patches
        output_patches: torch.Tensor = unfold(input=output_shift, window_size=self.window_size) \
            if self.make_windows else output_shift
        # Perform window attention
        output_attention: torch.Tensor = self.window_attention(output_patches, mask=self.attention_mask)
        # Merge patches
        output_merge: torch.Tensor = fold(input=output_attention, window_size=self.window_size, height=height,
                                          width=width) if self.make_windows else output_attention
        # Reverse shift if utilized
        if self.shift_size > 0:
            output_shift = torch.roll(input=output_merge, shifts=(self.shift_size, self.shift_size),
                                      dims=(-1, -2))
        else:
            output_shift = output_merge
        # Perform normalization (LayerNorm acts on the last axis, hence the channels-last round trip)
        output_normalize: torch.Tensor = bhwc_to_bchw(self.normalization_1(bchw_to_bhwc(output_shift)))
        # Skip connection
        output_skip: torch.Tensor = self.dropout(output_normalize) + input
        # Feed forward network on the token sequence (B, H * W, C), normalization and skip connection
        output_feed_forward: torch.Tensor = self.feed_forward_network(
            output_skip.flatten(start_dim=2).permute(0, 2, 1)).permute(0, 2, 1)
        output_feed_forward = output_feed_forward.view(batch_size, channels, height, width)
        output_normalize = bhwc_to_bchw(self.normalization_2(bchw_to_bhwc(output_feed_forward)))
        output: torch.Tensor = output_skip + self.dropout(output_normalize)
        return output
|
||
|
|
||
|
|
||
|
class DeformableSwinTransformerBlock(SwinTransformerBlock):
    r""" This class implements a deformable version of the Swin Transformer block.
    Inspired by: https://arxiv.org/pdf/2201.00520

    A small convolutional offset network predicts one 2D sampling offset field per attention head. The
    input is resampled with these offsets before the regular Swin transformer block is applied.

    Args:
        in_channels (int): Number of input channels
        input_resolution (Tuple[int, int]): Input resolution
        number_of_heads (int): Number of attention heads to be utilized
        window_size (int): Window size to be utilized
        shift_size (int): Shifting size to be used
        ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
        dropout (float): Dropout in input mapping
        dropout_attention (float): Dropout rate of attention map
        dropout_path (float): Dropout in main path
        sequential_self_attention (bool): If true sequential self-attention is performed
        offset_downscale_factor (int): Downscale factor of offset network
    """

    def __init__(self,
                 in_channels: int,
                 input_resolution: Tuple[int, int],
                 number_of_heads: int,
                 window_size: int = 7,
                 shift_size: int = 0,
                 ff_feature_ratio: int = 4,
                 dropout: float = 0.0,
                 dropout_attention: float = 0.0,
                 dropout_path: float = 0.0,
                 sequential_self_attention: bool = False,
                 offset_downscale_factor: int = 2) -> None:
        # Call super constructor
        super(DeformableSwinTransformerBlock, self).__init__(
            in_channels=in_channels,
            input_resolution=input_resolution,
            number_of_heads=number_of_heads,
            window_size=window_size,
            shift_size=shift_size,
            ff_feature_ratio=ff_feature_ratio,
            dropout=dropout,
            dropout_attention=dropout_attention,
            dropout_path=dropout_path,
            sequential_self_attention=sequential_self_attention
        )
        # Save parameter
        self.offset_downscale_factor: int = offset_downscale_factor
        self.number_of_heads: int = number_of_heads
        # Make default offsets
        self.__make_default_offsets()
        # Init offset network: strided depth-wise convolution followed by a point-wise convolution that
        # emits two offset channels per attention head
        self.offset_network: nn.Module = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=5, stride=offset_downscale_factor,
                      padding=3, groups=in_channels, bias=True),
            nn.GELU(),
            nn.Conv2d(in_channels=in_channels, out_channels=2 * self.number_of_heads, kernel_size=1, stride=1,
                      padding=0, bias=True)
        )

    def __make_default_offsets(self) -> None:
        """ Method generates the default sampling grid (inspired by kornia)

        Registers the buffer ``default_grid`` of the shape (1, height, width, 2). The last axis holds the
        (x, y) coordinate of each location normalized to [-1, 1], as expected by ``F.grid_sample``.
        """
        height, width = self.input_resolution[0], self.input_resolution[1]
        device = self.window_attention.tau.device
        # Init x and y coordinates
        x: torch.Tensor = torch.linspace(0, width - 1, width, device=device)
        y: torch.Tensor = torch.linspace(0, height - 1, height, device=device)
        # Normalize coordinates to a range of [-1, 1]. The divisor is guarded so that an axis of length
        # one maps to -1 instead of producing NaN from 0 / 0.
        x = (x / max(width - 1, 1) - 0.5) * 2
        y = (y / max(height - 1, 1) - 0.5) * 2
        # Make grid [height, width, 2]: x varies along the width axis, y along the height axis
        grid: torch.Tensor = torch.stack((x.unsqueeze(0).expand(height, width),
                                          y.unsqueeze(1).expand(height, width)), dim=-1)
        # Add batch axis -> [1, height, width, 2] and register in module
        self.register_buffer("default_grid", grid.unsqueeze(dim=0))

    def update_resolution(self,
                          new_window_size: int,
                          new_input_resolution: Tuple[int, int]) -> None:
        """ Method updates the window size and so the pair-wise relative positions.

        Args:
            new_window_size (int): New window size
            new_input_resolution (Tuple[int, int]): New input resolution
        """
        # Update resolution and window size
        super(DeformableSwinTransformerBlock, self).update_resolution(new_window_size=new_window_size,
                                                                      new_input_resolution=new_input_resolution)
        # Update default sampling grid
        self.__make_default_offsets()

    def forward(self,
                input: torch.Tensor) -> torch.Tensor:
        """ Forward pass

        Args:
            input (torch.Tensor): Input tensor of the shape [B, C, H, W]

        Returns:
            output (torch.Tensor): Output tensor of the shape [B, C, H, W]
        """
        # Get input shape
        batch_size, channels, height, width = input.shape
        # Compute offsets with 2 * number of heads channels at a spatially downscaled resolution
        offsets: torch.Tensor = self.offset_network(input)
        # Upscale offsets to the shape [batch size, 2 * number of heads, height, width]
        offsets = F.interpolate(input=offsets, size=(height, width), mode="bilinear", align_corners=True)
        # Reshape offsets to [batch size, number of heads, height, width, 2]
        offsets = offsets.reshape(batch_size, -1, 2, height, width).permute(0, 1, 3, 4, 2)
        # Flatten batch size and number of heads and apply tanh to bound the offsets to (-1, 1)
        offsets = offsets.reshape(-1, height, width, 2).tanh()
        # Cast a local copy of the default grid to the input data type; the registered buffer itself is
        # left untouched so that a forward pass never changes module state
        default_grid: torch.Tensor = self.default_grid.to(dtype=input.dtype)
        # Construct offset grid (the single default grid broadcasts over batch size * number of heads)
        offset_grid: torch.Tensor = default_grid + offsets
        # Reshape input to [batch size * number of heads, channels / number of heads, height, width]
        input = input.view(batch_size, self.number_of_heads, channels // self.number_of_heads, height,
                           width).flatten(start_dim=0, end_dim=1)
        # Apply sampling grid; coordinates are clipped so every sample stays inside the feature map
        input_resampled: torch.Tensor = F.grid_sample(input=input, grid=offset_grid.clip(min=-1, max=1),
                                                      mode="bilinear", align_corners=True,
                                                      padding_mode="reflection")
        # Reshape resampled tensor again to [batch size, channels, height, width]
        input_resampled = input_resampled.view(batch_size, channels, height, width)
        output: torch.Tensor = super(DeformableSwinTransformerBlock, self).forward(input=input_resampled)
        return output
|
||
|
|
||
|
|
||
|
class PatchMerging(nn.Module):
    """ Patch merging layer that halves the spatial resolution and doubles the channel count.

    Every non-overlapping 2 x 2 neighbourhood is concatenated along the channel axis (4 * C features),
    layer-normalized and linearly projected to 2 * C features.

    Args:
        in_channels (int): Number of input channels
    """

    def __init__(self,
                 in_channels: int) -> None:
        super(PatchMerging, self).__init__()
        merged_channels: int = 4 * in_channels
        # Normalization applied to the concatenated 2 x 2 neighbourhood
        self.normalization: nn.Module = nn.LayerNorm(normalized_shape=merged_channels)
        # Bias-free projection from 4 * C to 2 * C features
        self.linear_mapping: nn.Module = nn.Linear(in_features=merged_channels, out_features=2 * in_channels,
                                                   bias=False)

    def forward(self,
                input: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        Args:
            input (torch.Tensor): Input tensor of the shape [B, C, H, W]

        Returns:
            output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
        """
        batch_size, _, _, _ = input.shape
        # Switch to channels-last layout (B, H, W, C)
        patches: torch.Tensor = bchw_to_bhwc(input)
        # Cut both spatial axes into steps of two -> (B, H // 2, W // 2, C, 2, 2)
        patches = patches.unfold(dimension=1, size=2, step=2)
        patches = patches.unfold(dimension=2, size=2, step=2)
        # Concatenate each 2 x 2 neighbourhood along the feature axis -> (B, H // 2, W // 2, 4 * C)
        merged: torch.Tensor = patches.reshape(batch_size, patches.shape[1], patches.shape[2], -1)
        # Normalize, project and return to channels-first layout
        projected: torch.Tensor = self.linear_mapping(self.normalization(merged))
        return bhwc_to_bchw(projected)
|
||
|
|
||
|
|
||
|
class PatchEmbedding(nn.Module):
    """ Module embeds a given image into patch embeddings.

    Non-overlapping patches are projected by a convolution whose kernel size and stride both equal the
    patch size; the result is layer-normalized over the channel axis.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        patch_size (int): Patch size to be utilized
    """

    def __init__(self,
                 in_channels: int = 3,
                 out_channels: int = 96,
                 patch_size: int = 4) -> None:
        super(PatchEmbedding, self).__init__()
        self.out_channels: int = out_channels
        patch_shape: Tuple[int, int] = (patch_size, patch_size)
        # Linear embedding of each patch, expressed as a strided convolution
        self.linear_embedding: nn.Module = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                                     kernel_size=patch_shape,
                                                     stride=patch_shape)
        # Layer normalization over the embedding channels
        self.normalization: nn.Module = nn.LayerNorm(normalized_shape=out_channels)

    def forward(self,
                input: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        Args:
            input (torch.Tensor): Input image of the shape (B, C_in, H, W)

        Returns:
            embedding (torch.Tensor): Embedding of the shape (B, C_out, H / patch size, W / patch size)
        """
        patches: torch.Tensor = self.linear_embedding(input)
        # LayerNorm acts on the last axis, so normalize in channels-last layout and convert back.
        normalized: torch.Tensor = self.normalization(bchw_to_bhwc(patches))
        return bhwc_to_bchw(normalized)
|
||
|
|
||
|
|
||
|
class SwinTransformerStage(nn.Module):
    r""" A stage of the Swin transformer: an optional patch merging followed by a stack of blocks.

    Args:
        in_channels (int): Number of input channels
        depth (int): Depth of the stage (number of layers)
        downscale (bool): If true input is downsampled (see Fig. 3 or V1 paper)
        input_resolution (Tuple[int, int]): Input resolution
        number_of_heads (int): Number of attention heads to be utilized
        window_size (int): Window size to be utilized
        ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
        dropout (float): Dropout in input mapping
        dropout_attention (float): Dropout rate of attention map
        dropout_path (Union[List[float], float]): Dropout in main path, a list holds one rate per block
        use_checkpoint (bool): If true checkpointing is utilized
        sequential_self_attention (bool): If true sequential self-attention is performed
        use_deformable_block (bool): If true deformable block is used
    """

    def __init__(self,
                 in_channels: int,
                 depth: int,
                 downscale: bool,
                 input_resolution: Tuple[int, int],
                 number_of_heads: int,
                 window_size: int = 7,
                 ff_feature_ratio: int = 4,
                 dropout: float = 0.0,
                 dropout_attention: float = 0.0,
                 dropout_path: Union[List[float], float] = 0.0,
                 use_checkpoint: bool = False,
                 sequential_self_attention: bool = False,
                 use_deformable_block: bool = False) -> None:
        super().__init__()
        self.use_checkpoint: bool = use_checkpoint
        self.downscale: bool = downscale
        if downscale:
            # Patch merging halves the resolution and doubles the channels seen by the blocks
            self.downsample: nn.Module = PatchMerging(in_channels=in_channels)
            self.input_resolution: Tuple[int, int] = (input_resolution[0] // 2, input_resolution[1] // 2)
            in_channels = in_channels * 2
        else:
            self.downsample: nn.Module = nn.Identity()
            self.input_resolution: Tuple[int, int] = input_resolution
        # Pick the block implementation
        block_class = DeformableSwinTransformerBlock if use_deformable_block else SwinTransformerBlock
        # A list carries an individual path dropout rate for every block
        rate_per_block: bool = isinstance(dropout_path, list)
        # Blocks with an even index are not shifted, blocks with an odd index shift by half a window
        self.blocks: nn.ModuleList = nn.ModuleList([
            block_class(in_channels=in_channels,
                        input_resolution=self.input_resolution,
                        number_of_heads=number_of_heads,
                        window_size=window_size,
                        shift_size=window_size // 2 if (index % 2) != 0 else 0,
                        ff_feature_ratio=ff_feature_ratio,
                        dropout=dropout,
                        dropout_attention=dropout_attention,
                        dropout_path=dropout_path[index] if rate_per_block else dropout_path,
                        sequential_self_attention=sequential_self_attention)
            for index in range(depth)])

    def update_resolution(self,
                          new_window_size: int,
                          new_input_resolution: Tuple[int, int]) -> None:
        """ Method updates the resolution and the window size of the stage and forwards both to every block.

        Args:
            new_window_size (int): New window size
            new_input_resolution (Tuple[int, int]): New input resolution (before the optional downscaling)
        """
        # The blocks operate on the resolution after the optional patch merging
        if self.downscale:
            new_input_resolution = (new_input_resolution[0] // 2, new_input_resolution[1] // 2)
        self.input_resolution: Tuple[int, int] = new_input_resolution
        for layer in self.blocks:
            layer.update_resolution(new_window_size=new_window_size, new_input_resolution=self.input_resolution)

    def forward(self,
                input: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        Args:
            input (torch.Tensor): Input tensor of the shape [B, C, H, W]

        Returns:
            output (torch.Tensor): Output of the last block; with downscaling the patch merging yields
                [B, 2 * C, H // 2, W // 2] before the blocks are applied
        """
        # Identity if the stage does not downscale
        output: torch.Tensor = self.downsample(input)
        for layer in self.blocks:
            # Checkpointing is applied block by block if requested
            output = checkpoint.checkpoint(layer, output) if self.use_checkpoint else layer(output)
        return output
|
||
|
|
||
|
|
||
|
class SwinTransformerV2(nn.Module):
    r""" Swin Transformer V2
        A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` -
          https://arxiv.org/pdf/2111.09883

    Args:
        in_channels (int): Number of input channels
        embedding_channels (int): Number of channels produced by the patch embedding
        depths (Tuple[int, ...]): Depth (number of blocks) of each stage
        input_resolution (Tuple[int, int]): Input resolution
        number_of_heads (Tuple[int, ...]): Number of attention heads to be utilized in each stage
        num_classes (int): Number of output classes
        window_size (int): Window size to be utilized
        patch_size (int): Patch size to be utilized
        ff_feature_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
        dropout (float): Dropout in input mapping
        dropout_attention (float): Dropout rate of attention map
        dropout_path (float): Largest path dropout rate, the rates are spaced linearly from 0 over all blocks
        use_checkpoint (bool): If true checkpointing is utilized
        sequential_self_attention (bool): If true sequential self-attention is performed
        use_deformable_block (bool): If true deformable block is used (in every stage but the first one)
    """

    def __init__(self,
                 in_channels: int,
                 embedding_channels: int,
                 depths: Tuple[int, ...],
                 input_resolution: Tuple[int, int],
                 number_of_heads: Tuple[int, ...],
                 num_classes: int = 1000,
                 window_size: int = 7,
                 patch_size: int = 4,
                 ff_feature_ratio: int = 4,
                 dropout: float = 0.0,
                 dropout_attention: float = 0.0,
                 dropout_path: float = 0.2,
                 use_checkpoint: bool = False,
                 sequential_self_attention: bool = False,
                 use_deformable_block: bool = False) -> None:
        # Call super constructor
        super(SwinTransformerV2, self).__init__()
        # Save parameters
        self.patch_size: int = patch_size
        self.input_resolution: Tuple[int, int] = input_resolution
        self.window_size: int = window_size
        # Init patch embedding
        self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_channels, out_channels=embedding_channels,
                                                         patch_size=patch_size)
        # Compute patch resolution
        patch_resolution: Tuple[int, int] = (input_resolution[0] // patch_size, input_resolution[1] // patch_size)
        # Path dropout dependent on depth, one rate per block over the whole network
        dropout_path_rates: List[float] = torch.linspace(0., dropout_path, sum(depths)).tolist()
        # Init stages
        self.stages: nn.ModuleList = nn.ModuleList()
        first_block: int = 0  # index of the first block of the current stage in dropout_path_rates
        for index, (depth, number_of_head) in enumerate(zip(depths, number_of_heads)):
            # Stage 0 does not downscale, so stage 0 and stage 1 both receive the patch-embedding scale
            scale: int = 2 ** max(index - 1, 0)
            self.stages.append(
                SwinTransformerStage(
                    in_channels=embedding_channels * scale,
                    depth=depth,
                    downscale=index != 0,
                    input_resolution=(patch_resolution[0] // scale, patch_resolution[1] // scale),
                    number_of_heads=number_of_head,
                    window_size=window_size,
                    ff_feature_ratio=ff_feature_ratio,
                    dropout=dropout,
                    dropout_attention=dropout_attention,
                    dropout_path=dropout_path_rates[first_block:first_block + depth],
                    use_checkpoint=use_checkpoint,
                    sequential_self_attention=sequential_self_attention,
                    use_deformable_block=use_deformable_block and (index > 0)
                ))
            first_block += depth
        # Every stage but the first doubles the channels, hence the last stage emits
        # embedding_channels * 2 ** (number of stages - 1) channels. The previous expression
        # `2 ** len(depths) - 1` was parsed as `(2 ** len(depths)) - 1` and sized the head wrongly.
        final_channels: int = embedding_channels * 2 ** max(len(self.stages) - 1, 0)
        # Init final adaptive average pooling, and classification head
        self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1)
        self.head: nn.Module = nn.Linear(in_features=final_channels, out_features=num_classes)

    def update_resolution(self,
                          new_input_resolution: Optional[Tuple[int, int]] = None,
                          new_window_size: Optional[int] = None) -> None:
        """ Method updates the image resolution to be processed and window size and so the pair-wise relative positions.

        Args:
            new_input_resolution (Optional[Tuple[int, int]]): New input resolution if None current resolution is used
            new_window_size (Optional[int]): New window size if None current window size is used
        """
        # Fall back to the current configuration for omitted arguments
        if new_input_resolution is None:
            new_input_resolution = self.input_resolution
        if new_window_size is None:
            new_window_size = self.window_size
        # Remember the new configuration, otherwise a later call that omits an argument would
        # silently fall back to the constructor values instead of the current ones
        self.input_resolution = new_input_resolution
        self.window_size = new_window_size
        # Compute new patch resolution
        new_patch_resolution: Tuple[int, int] = (new_input_resolution[0] // self.patch_size,
                                                 new_input_resolution[1] // self.patch_size)
        # Update resolution of each stage
        for index, stage in enumerate(self.stages):
            # Same per-stage scale as used in the constructor
            scale: int = 2 ** max(index - 1, 0)
            stage.update_resolution(new_window_size=new_window_size,
                                    new_input_resolution=(new_patch_resolution[0] // scale,
                                                          new_patch_resolution[1] // scale))

    def forward_features(self,
                         input: torch.Tensor) -> List[torch.Tensor]:
        """ Forward pass to extract feature maps of each stage.

        Args:
            input (torch.Tensor): Input images of the shape (B, C, H, W)

        Returns:
            features (List[torch.Tensor]): List of feature maps from each stage
        """
        # Perform patch embedding
        output: torch.Tensor = self.patch_embedding(input)
        # Init list to store feature
        features: List[torch.Tensor] = []
        # Forward pass of each stage
        for stage in self.stages:
            output = stage(output)
            features.append(output)
        return features

    def forward(self,
                input: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        Args:
            input (torch.Tensor): Input images of the shape (B, C, H, W)

        Returns:
            classification (torch.Tensor): Classification of the shape (B, num_classes)
        """
        # Perform patch embedding
        output: torch.Tensor = self.patch_embedding(input)
        # Forward pass of each stage (kept separate from forward_features so that the
        # intermediate feature maps are not collected in a list)
        for stage in self.stages:
            output = stage(output)
        # Perform average pooling and drop the two singleton spatial axes: (B, C, 1, 1) -> (B, C).
        # nn.Linear acts on the last axis, so the head needs the channels there.
        output = self.average_pool(output).flatten(start_dim=1)
        # Predict classification
        classification: torch.Tensor = self.head(output)
        return classification
|