commit
b049a5c5c6
@ -0,0 +1,156 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch import nn, Tensor
|
||||||
|
from torch.nn.modules.transformer import _get_activation_fn
|
||||||
|
|
||||||
|
|
||||||
|
def add_ml_decoder_head(model):
    """Replace the classification head of ``model`` with an ``MLDecoder`` head.

    The head attribute is chosen by inspecting the model:

    * ``global_pool`` + ``fc`` (most CNN models, like ResNet50) -> ``model.fc``
    * ``global_pool`` + ``classifier`` (EfficientNet) -> ``model.classifier``
    * class name containing ``RegNet`` or ``TResNet`` -> ``model.head``

    For the first two layouts ``model.global_pool`` is replaced by ``nn.Identity``.
    If the model exposes ``drop_rate`` it is set to 0, because the ML-Decoder has
    its own inner dropout.

    Args:
        model: Backbone exposing ``num_classes`` and ``num_features``.

    Returns:
        The same ``model`` object, modified in place.

    Note:
        For an unsupported model a message is printed and the process exits
        with status -1.
    """
    has_global_pool = hasattr(model, 'global_pool')
    if has_global_pool and hasattr(model, 'fc'):  # most CNN models, like Resnet50
        head_attr, replace_pool = 'fc', True
    elif has_global_pool and hasattr(model, 'classifier'):  # EfficientNet
        head_attr, replace_pool = 'classifier', True
    elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name():  # hasattr(model, 'head')
        head_attr, replace_pool = 'head', False
    else:
        print("Model code-writing is not aligned currently with ml-decoder")
        exit(-1)

    if replace_pool:
        # the decoder consumes the raw spatial feature map, so pooling is disabled
        model.global_pool = nn.Identity()
    delattr(model, head_attr)
    decoder_head = MLDecoder(num_classes=model.num_classes, initial_num_features=model.num_features)
    setattr(model, head_attr, decoder_head)

    if hasattr(model, 'drop_rate'):  # Ml-Decoder has inner dropout
        model.drop_rate = 0
    return model
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoderLayerOptimal(nn.Module):
    """Transformer decoder layer without a self-attention block.

    The layer applies, in order: a dropout residual on the queries + LayerNorm,
    cross-attention from the queries to ``memory`` + LayerNorm, and a two-layer
    feed-forward network + LayerNorm (post-norm ordering).

    Args:
        d_model: Feature dimension of queries and memory.
        nhead: Number of heads of the cross-attention module.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability used by every dropout module and by the attention.
        activation: Activation name resolved by ``_get_activation_fn``.
        layer_norm_eps: Epsilon of the three LayerNorm modules.
    """

    def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5) -> None:
        super(TransformerDecoderLayerOptimal, self).__init__()
        # NOTE: module creation order is kept as-is; it fixes state_dict ordering
        # and the order in which random initialisation is drawn.
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.activation = _get_activation_fn(activation)

    @property
    def self_attn(self):
        """Alias of ``multihead_attn`` for ``nn.TransformerDecoder`` compatibility.

        Recent PyTorch releases read ``layers[0].self_attn.batch_first`` inside
        ``nn.TransformerDecoder.forward``. Exposing the alias as a property (not a
        registered sub-module) keeps the state_dict keys unchanged.
        """
        return self.multihead_attn

    def __setstate__(self, state):
        # objects pickled without an activation entry fall back to ReLU
        if 'activation' not in state:
            state['activation'] = torch.nn.functional.relu
        super(TransformerDecoderLayerOptimal, self).__setstate__(state)

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                tgt_is_causal: bool = False,
                memory_is_causal: bool = False) -> Tensor:
        """Run the layer.

        Args:
            tgt: Query tensor passed to the cross-attention as ``query``.
            memory: Tensor passed to the cross-attention as ``key`` and ``value``.
            tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask:
                Accepted for signature compatibility with ``nn.TransformerDecoder``;
                none of them is used by this layer.
            tgt_is_causal, memory_is_causal: Accepted because recent
                ``nn.TransformerDecoder`` versions pass them to every layer; ignored.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        # no self-attention here: the first residual only adds a dropped-out copy of tgt
        tgt = tgt + self.dropout1(tgt)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory)[0]  # [0] -> attention output, weights dropped
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt
|
||||||
|
|
||||||
|
|
||||||
|
# @torch.jit.script
|
||||||
|
# class ExtrapClasses(object):
|
||||||
|
# def __init__(self, num_queries: int, group_size: int):
|
||||||
|
# self.num_queries = num_queries
|
||||||
|
# self.group_size = group_size
|
||||||
|
#
|
||||||
|
# def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap:
|
||||||
|
# torch.Tensor):
|
||||||
|
# # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size)
|
||||||
|
# h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups])
|
||||||
|
# w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size))
|
||||||
|
# out = (h * w).sum(dim=2) + class_embed_b
|
||||||
|
# out = out.view((h.shape[0], self.group_size * self.num_queries))
|
||||||
|
# return out
|
||||||
|
|
||||||
|
# TorchScript class: applies one weight matrix per query group and writes the
# result into a caller-provided output tensor (nothing is returned).
@torch.jit.script
class GroupFC(object):
    def __init__(self, embed_len_decoder: int):
        # number of groups iterated over in __call__
        self.embed_len_decoder = embed_len_decoder

    def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
        # For every group g: out_extrap[:, g, :] = h[:, g, :] @ duplicate_pooling[g]
        # out_extrap is modified in place.
        for group_idx in range(self.embed_len_decoder):
            group_features = h[:, group_idx, :]
            group_weight = duplicate_pooling[group_idx, :, :]
            out_extrap[:, group_idx, :] = torch.matmul(group_features, group_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class MLDecoder(nn.Module):
    """ML-Decoder classification head.

    A fixed (non-trainable) set of query embeddings cross-attends to the spatial
    features of a backbone; each decoded query is then expanded by a per-group
    linear map ("group fully-connected") into ``duplicate_factor`` logits.

    Args:
        num_classes: Number of output logits.
        num_of_groups: Number of query groups; a negative value selects 100.
            The value is capped at ``num_classes``.
        decoder_embedding: Decoder feature width; a negative value selects 768.
        initial_num_features: Channel count of the incoming backbone features.
    """

    def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
        super(MLDecoder, self).__init__()
        embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
        if embed_len_decoder > num_classes:
            embed_len_decoder = num_classes

        # switching to 768 initial embeddings
        decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
        # NOTE: attribute name (incl. spelling) is part of the state_dict keys - do not rename
        self.embed_standart = nn.Linear(initial_num_features, decoder_embedding)

        # decoder
        decoder_dropout = 0.1
        num_layers_decoder = 1
        dim_feedforward = 2048
        layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding,
                                                      dim_feedforward=dim_feedforward, dropout=decoder_dropout)
        self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder)

        # non-learnable queries
        self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding)
        self.query_embed.requires_grad_(False)

        # group fully-connected
        self.num_classes = num_classes
        # Exact integer ceil(num_classes / embed_len_decoder). The former
        # ``int(x + 0.999)`` rounded down whenever the fractional part was below
        # 0.001 (e.g. 5001 classes / 5000 groups), producing fewer logits than classes.
        self.duplicate_factor = (num_classes + embed_len_decoder - 1) // embed_len_decoder
        self.duplicate_pooling = torch.nn.Parameter(
            torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor))
        self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
        torch.nn.init.xavier_normal_(self.duplicate_pooling)
        torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
        self.group_fc = GroupFC(embed_len_decoder)

    def forward(self, x):
        """Compute class logits from backbone features.

        Args:
            x: Either a 4-D feature map (e.g. [bs, 2048, 7, 7]), which is flattened
                over its spatial axes and transposed to token-major layout, or an
                already token-major tensor, which is used as is.

        Returns:
            Tensor of logits with ``num_classes`` entries per sample.
        """
        if len(x.shape) == 4:  # [bs,2048, 7,7]
            embedding_spatial = x.flatten(2).transpose(1, 2)
        else:  # already [bs, tokens, features]
            embedding_spatial = x
        embedding_spatial_768 = self.embed_standart(embedding_spatial)
        embedding_spatial_768 = torch.nn.functional.relu(embedding_spatial_768, inplace=True)

        bs = embedding_spatial_768.shape[0]
        query_embed = self.query_embed.weight
        # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
        tgt = query_embed.unsqueeze(1).expand(-1, bs, -1)  # no allocation of memory with expand
        h = self.decoder(tgt, embedding_spatial_768.transpose(0, 1))  # [embed_len_decoder, batch, 768]
        h = h.transpose(0, 1)

        # group_fc fills out_extrap in place, one group at a time
        out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
        self.group_fc(h, self.duplicate_pooling, out_extrap)
        # groups * duplicate_factor may exceed num_classes; surplus logits are cut off
        h_out = out_extrap.flatten(1)[:, :self.num_classes]
        h_out += self.duplicate_pooling_bias
        logits = h_out
        return logits
|
@ -0,0 +1,931 @@
|
|||||||
|
""" Swin Transformer V2
|
||||||
|
|
||||||
|
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
|
||||||
|
- https://arxiv.org/pdf/2111.09883
|
||||||
|
|
||||||
|
Code adapted from https://github.com/ChristophReich1996/Swin-Transformer-V2, original copyright/license info below
|
||||||
|
|
||||||
|
This implementation is experimental and subject to change in manners that will break weight compat:
|
||||||
|
* The size of the pos embed MLP is not spelled out in the paper in terms of dim: fixed for all models? vary with num_heads?
|
||||||
|
* currently dim is fixed, I feel it may make sense to scale with num_heads (dim per head)
|
||||||
|
* The specifics of the memory saving 'sequential attention' are not detailed, Christoph Reich has an impl at
|
||||||
|
GitHub link above. It needs further investigation as throughput vs mem tradeoff doesn't appear beneficial.
|
||||||
|
* num_heads per stage is not detailed for Huge and Giant model variants
|
||||||
|
* 'Giant' is 3B params in paper but ~2.6B here despite matching paper dim + block counts
|
||||||
|
|
||||||
|
Noteworthy additions over official Swin v1:
|
||||||
|
* MLP relative position embedding is looking promising and adapts to different image/window sizes
|
||||||
|
* This impl has been designed to allow easy change of image size with matching window size changes
|
||||||
|
* Non-square image size and window size are supported
|
||||||
|
|
||||||
|
Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
|
||||||
|
"""
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# Swin Transformer V2 reimplementation
|
||||||
|
# Copyright (c) 2021 Christoph Reich
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Written by Christoph Reich
|
||||||
|
# --------------------------------------------------------
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Tuple, Optional, List, Union, Any, Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
|
||||||
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
from .fx_features import register_notrace_function
|
||||||
|
from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
|
||||||
|
from .layers import DropPath, Mlp, to_2tuple, _assert
|
||||||
|
from .registry import register_model
|
||||||
|
from .vision_transformer import checkpoint_filter_fn
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
    """Build a default pretrained-config dict for a model in this file.

    Args:
        url: Value stored under the ``'url'`` key.
        **kwargs: Extra entries; they override the defaults on key collision.

    Returns:
        A new dict holding the defaults merged with ``kwargs``.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=0.9,
        interpolation='bicubic',
        fixed_input_size=True,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.proj',
        classifier='head',
    )
    cfg.update(kwargs)
    return cfg
|
||||||
|
|
||||||
|
|
||||||
|
# Default configs per registered variant (experimental patch models, no weight URLs yet).
# Every entry differs only in its input resolution.
default_cfgs = dict(
    swin_v2_cr_tiny_384=_cfg(url="", input_size=(3, 384, 384), crop_pct=1.0),
    swin_v2_cr_tiny_224=_cfg(url="", input_size=(3, 224, 224), crop_pct=1.0),
    swin_v2_cr_small_384=_cfg(url="", input_size=(3, 384, 384), crop_pct=1.0),
    swin_v2_cr_small_224=_cfg(url="", input_size=(3, 224, 224), crop_pct=1.0),
    swin_v2_cr_base_384=_cfg(url="", input_size=(3, 384, 384), crop_pct=1.0),
    swin_v2_cr_base_224=_cfg(url="", input_size=(3, 224, 224), crop_pct=1.0),
    swin_v2_cr_large_384=_cfg(url="", input_size=(3, 384, 384), crop_pct=1.0),
    swin_v2_cr_large_224=_cfg(url="", input_size=(3, 224, 224), crop_pct=1.0),
    swin_v2_cr_huge_384=_cfg(url="", input_size=(3, 384, 384), crop_pct=1.0),
    swin_v2_cr_huge_224=_cfg(url="", input_size=(3, 224, 224), crop_pct=1.0),
    swin_v2_cr_giant_384=_cfg(url="", input_size=(3, 384, 384), crop_pct=1.0),
    swin_v2_cr_giant_224=_cfg(url="", input_size=(3, 224, 224), crop_pct=1.0),
)
|
||||||
|
|
||||||
|
|
||||||
|
def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
    """Reorder the axes of a 4-D tensor from (B, C, H, W) to (B, H, W, C)."""
    channels_last_order = (0, 2, 3, 1)
    return x.permute(*channels_last_order)
|
||||||
|
|
||||||
|
|
||||||
|
def bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor:
    """Reorder the axes of a 4-D tensor from (B, H, W, C) to (B, C, H, W)."""
    channels_first_order = (0, 3, 1, 2)
    return x.permute(*channels_first_order)
|
||||||
|
|
||||||
|
|
||||||
|
def window_partition(x, window_size: Tuple[int, int]):
    """Split a feature map into non-overlapping windows.

    Args:
        x: (B, H, W, C)
        window_size (Tuple[int, int]): window (height, width)

    Returns:
        windows: (num_windows*B, window_size[0], window_size[1], C)
    """
    win_h = window_size[0]
    win_w = window_size[1]
    B, H, W, C = x.shape
    # expose the window grid: (B, rows, win_h, cols, win_w, C)
    grid = x.view(B, H // win_h, win_h, W // win_w, win_w, C)
    # bring both grid axes to the front, then fold them into the batch axis
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, win_h, win_w, C)
|
||||||
|
|
||||||
|
|
||||||
|
@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
    """Merge windows back into a feature map (inverse of ``window_partition``).

    Args:
        windows: (num_windows * B, window_size[0], window_size[1], C)
        window_size (Tuple[int, int]): Window size
        img_size (Tuple[int, int]): Image size

    Returns:
        x: (B, H, W, C)
    """
    # typing.Tuple is used in the signature (as elsewhere in this file) so the
    # annotations do not require Python 3.9+ to be evaluated.
    H, W = img_size
    rows = H // window_size[0]
    cols = W // window_size[1]
    # integer arithmetic: number of windows per image -> batch size
    B = windows.shape[0] // (rows * cols)
    x = windows.view(B, rows, cols, window_size[0], window_size[1], -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
|
||||||
|
|
||||||
|
|
||||||
|
class WindowMultiHeadAttention(nn.Module):
    r"""This class implements window-based Multi-Head-Attention with log-spaced continuous position bias.

    Args:
        dim (int): Number of input features
        num_heads (int): Number of attention heads
        window_size (Tuple[int, int]): Window size as (height, width)
        drop_attn (float): Dropout rate of attention map
        drop_proj (float): Dropout rate after projection
        meta_hidden_dim (int): Number of hidden features in the two layer MLP meta network
        sequential_attn (bool): If true sequential self-attention is performed (not implemented yet)
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            window_size: Tuple[int, int],
            drop_attn: float = 0.0,
            drop_proj: float = 0.0,
            meta_hidden_dim: int = 384,  # FIXME what's the optimal value?
            sequential_attn: bool = False,
    ) -> None:
        super(WindowMultiHeadAttention, self).__init__()
        assert dim % num_heads == 0, \
            "The number of input features (in_features) are not divisible by the number of heads (num_heads)."
        self.in_features: int = dim
        self.window_size: Tuple[int, int] = window_size
        self.num_heads: int = num_heads
        self.sequential_attn: bool = sequential_attn

        self.qkv = nn.Linear(in_features=dim, out_features=dim * 3, bias=True)
        self.attn_drop = nn.Dropout(drop_attn)
        self.proj = nn.Linear(in_features=dim, out_features=dim, bias=True)
        self.proj_drop = nn.Dropout(drop_proj)
        # meta network for positional encodings
        self.meta_mlp = Mlp(
            2,  # x, y
            hidden_features=meta_hidden_dim,
            out_features=num_heads,
            act_layer=nn.ReLU,
            drop=0.  # FIXME should we add stochasticity?
        )
        # learnable per-head divisor applied to the cosine attention logits
        self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
        self._make_pair_wise_relative_positions()

    def _make_pair_wise_relative_positions(self) -> None:
        """Method initializes the pair-wise relative positions to compute the positional biases."""
        device = self.tau.device
        coordinates = torch.stack(torch.meshgrid([
            torch.arange(self.window_size[0], device=device),
            torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1)
        # (2, N, N) coordinate differences between every pair of window positions
        relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :]
        relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float()
        # signed log spacing: sign(d) * log(1 + |d|)
        relative_coordinates_log = torch.sign(relative_coordinates) * torch.log(
            1.0 + relative_coordinates.abs())
        # non-persistent: rebuilt from window_size, never stored in checkpoints
        self.register_buffer("relative_coordinates_log", relative_coordinates_log, persistent=False)

    def update_input_size(self, new_window_size: Tuple[int, int], **kwargs: Any) -> None:
        """Method updates the window size and so the pair-wise relative positions

        Args:
            new_window_size (Tuple[int, int]): New window size as (height, width)
            kwargs (Any): Unused
        """
        # Set new window size and new pair-wise relative positions
        self.window_size: Tuple[int, int] = new_window_size
        self._make_pair_wise_relative_positions()

    def _relative_positional_encodings(self) -> torch.Tensor:
        """Method computes the relative positional encodings

        Returns:
            relative_position_bias (torch.Tensor): Relative positional encodings
            (1, number of heads, window area, window area)
        """
        window_area = self.window_size[0] * self.window_size[1]
        relative_position_bias = self.meta_mlp(self.relative_coordinates_log)
        relative_position_bias = relative_position_bias.transpose(1, 0).reshape(
            self.num_heads, window_area, window_area
        )
        relative_position_bias = relative_position_bias.unsqueeze(0)
        return relative_position_bias

    def _forward_sequential(
            self,
            x: torch.Tensor,
            mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Placeholder for the memory-saving 'sequential' attention; always raises.

        Raises:
            NotImplementedError: Always.
        """
        # FIXME TODO figure out 'sequential' attention mentioned in paper (should reduce GPU memory)
        # An explicit raise is used instead of `assert False`, which is stripped under
        # `python -O` and would make this method silently return None.
        raise NotImplementedError("not implemented")

    def _forward_batch(
            self,
            x: torch.Tensor,
            mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """This function performs standard (non-sequential) scaled cosine self-attention.

        Args:
            x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
            mask (Optional[torch.Tensor]): Additive attention mask; its first dimension is
                taken as the number of windows per image

        Returns:
            Output tensor of the shape (B * windows, N, C)
        """
        Bw, L, C = x.shape

        # (3, Bw, heads, L, C // heads) -> query / key / value
        qkv = self.qkv(x).view(Bw, L, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        query, key, value = qkv.unbind(0)

        # compute attention map with scaled cosine attention
        denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1)
        attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6)  # clamp guards zero-norm vectors
        attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1)
        attn = attn + self._relative_positional_encodings()
        if mask is not None:
            # Apply mask if utilized
            num_win: int = mask.shape[0]
            attn = attn.view(Bw // num_win, num_win, self.num_heads, L, L)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)  # broadcast over batch and heads
            attn = attn.view(-1, self.num_heads, L, L)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ value).transpose(1, 2).reshape(Bw, L, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """ Forward pass.
        Args:
            x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
            mask (Optional[torch.Tensor]): Attention mask for the shift case

        Returns:
            Output tensor of the shape [B * windows, N, C]
        """
        if self.sequential_attn:
            return self._forward_sequential(x, mask)
        else:
            return self._forward_batch(x, mask)
|
||||||
|
|
||||||
|
|
||||||
|
class SwinTransformerBlock(nn.Module):
    r"""This class implements the Swin transformer block.

    The block runs (shifted) window attention followed by an MLP, each as a
    post-norm residual branch, with an optional extra norm on the main branch.

    Args:
        dim (int): Number of input channels
        num_heads (int): Number of attention heads to be utilized
        feat_size (Tuple[int, int]): Input resolution
        window_size (Tuple[int, int]): Window size to be utilized
        shift_size (Tuple[int, int]): Shifting size to be used
        mlp_ratio (float): Ratio of the hidden dimension in the FFN to the input channels
        drop (float): Dropout in input mapping
        drop_attn (float): Dropout rate of attention map
        drop_path (float): Dropout in main path
        extra_norm (bool): Insert extra norm on 'main' branch if True
        sequential_attn (bool): If true sequential self-attention is performed
        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            feat_size: Tuple[int, int],
            window_size: Tuple[int, int],
            shift_size: Tuple[int, int] = (0, 0),
            mlp_ratio: float = 4.0,
            drop: float = 0.0,
            drop_attn: float = 0.0,
            drop_path: float = 0.0,
            extra_norm: bool = False,
            sequential_attn: bool = False,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super(SwinTransformerBlock, self).__init__()
        self.dim: int = dim
        self.feat_size: Tuple[int, int] = feat_size
        # requested shift; the effective shift is derived in _calc_window_shift
        self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
        self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
        self.window_area = self.window_size[0] * self.window_size[1]

        # attn branch
        self.attn = WindowMultiHeadAttention(
            dim=dim,
            num_heads=num_heads,
            window_size=self.window_size,
            drop_attn=drop_attn,
            drop_proj=drop,
            sequential_attn=sequential_attn,
        )
        self.norm1 = norm_layer(dim)
        self.drop_path1 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()

        # mlp branch
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            drop=drop,
            out_features=dim,
        )
        self.norm2 = norm_layer(dim)
        self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()

        # extra norm layer mentioned for Huge/Giant models in V2 paper (FIXME may be in wrong spot?)
        self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()

        self._make_attention_mask()

    def _calc_window_shift(self, target_window_size):
        """Derive the effective window and shift size for the current ``feat_size``.

        Per axis, the window is clipped to the feature size, and the shift is set to 0
        when the feature size does not exceed the (clipped) window size.

        Returns:
            Tuple of (window_size, shift_size), each a tuple with one entry per axis.
        """
        window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)]
        shift_size = [0 if f <= w else s for f, w, s in zip(self.feat_size, window_size, self.target_shift_size)]
        return tuple(window_size), tuple(shift_size)

    def _make_attention_mask(self) -> None:
        """Method generates the attention mask used in shift case."""
        # Make masks for shift case
        if any(self.shift_size):
            # calculate attention mask for SW-MSA
            H, W = self.feat_size
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            # label each region of the map with its own id (cnt)
            cnt = 0
            for h in (
                    slice(0, -self.window_size[0]),
                    slice(-self.window_size[0], -self.shift_size[0]),
                    slice(-self.shift_size[0], None)):
                for w in (
                        slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], -self.shift_size[1]),
                        slice(-self.shift_size[1], None)):
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)  # num_windows, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_area)
            # pairs of positions with different region ids get -100, same-region pairs get 0
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        # non-persistent buffer: not written to checkpoints, rebuilt on size updates
        self.register_buffer("attn_mask", attn_mask, persistent=False)

    def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None:
        """Method updates the image resolution to be processed and window size and so the pair-wise relative positions.

        Args:
            new_window_size (Tuple[int, int]): New window size
            new_feat_size (Tuple[int, int]): New input resolution
        """
        # Update input resolution
        self.feat_size: Tuple[int, int] = new_feat_size
        self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(new_window_size))
        self.window_area = self.window_size[0] * self.window_size[1]
        self.attn.update_input_size(new_window_size=self.window_size)
        self._make_attention_mask()

    def _shifted_window_attn(self, x):
        """Apply (optionally cyclically shifted) window attention.

        Args:
            x: Tensor of shape (B, L, C); it is viewed as (B, H, W, C) with
                (H, W) = ``feat_size``.

        Returns:
            Tensor of shape (B, L, C).
        """
        H, W = self.feat_size
        B, L, C = x.shape
        x = x.view(B, H, W, C)

        # cyclic shift
        if any(self.shift_size):
            shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # num_windows * B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # num_windows * B, window_size * window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
        shifted_x = window_reverse(attn_windows, self.window_size, self.feat_size)  # B H' W' C

        # reverse cyclic shift
        if any(self.shift_size):
            x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(B, L, C)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x (torch.Tensor): Input tensor of the shape (B, L, C), with L = H * W of ``feat_size``

        Returns:
            output (torch.Tensor): Output tensor of the shape (B, L, C)
        """
        # NOTE post-norm branches (op -> norm -> drop)
        x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
        x = x + self.drop_path2(self.norm2(self.mlp(x)))
        x = self.norm3(x)  # main-branch norm enabled for some blocks (every 6 for Huge/Giant)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class PatchMerging(nn.Module):
    """Patch merging: gather every 2x2 spatial neighbourhood into the channel axis,
    normalize, then linearly reduce 4 * dim channels to 2 * dim.

    Args:
        dim (int): Number of input channels
        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized.
    """

    def __init__(self, dim: int, norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:
        super().__init__()
        merged_dim = 4 * dim
        self.norm = norm_layer(merged_dim)
        self.reduction = nn.Linear(in_features=merged_dim, out_features=2 * dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Forward pass.

        Args:
            x (torch.Tensor): Input tensor of the shape [B, C, H, W]

        Returns:
            output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
        """
        patches = bchw_to_bhwc(x)
        # two unfolds cut the map into 2x2 tiles along H and then W
        patches = patches.unfold(dimension=1, size=2, step=2)
        patches = patches.unfold(dimension=2, size=2, step=2)
        # permute maintains compat with ch order in official swin impl
        patches = patches.permute(0, 1, 2, 5, 4, 3).flatten(3)
        reduced = self.reduction(self.norm(patches))
        return bhwc_to_bchw(reduced)
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding

    Projects an image with a strided convolution (kernel == stride == patch size)
    and applies an optional norm over the channel axis. Output stays in
    (B, C, H, W) layout.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        grid_h = self.img_size[0] // self.patch_size[0]
        grid_w = self.img_size[1] // self.patch_size[1]
        self.grid_size = (grid_h, grid_w)
        self.num_patches = grid_h * grid_w

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        # the input must match the image size fixed at construction time
        _, _, H, W = x.shape
        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
        embedded = self.proj(x)
        # norm acts on the last axis, so go channels-last for it and back afterwards
        channels_last = embedded.permute(0, 2, 3, 1)
        normed = self.norm(channels_last)
        return normed.permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
class SwinTransformerStage(nn.Module):
    r"""One stage of the Swin transformer: optional patch merging followed by a stack of blocks.

    Args:
        embed_dim (int): Number of input channels (doubled internally when ``downscale`` is set)
        depth (int): Number of blocks in the stage
        downscale (bool): If true the input is downsampled by patch merging before the blocks
        feat_size (Tuple[int, int]): Input feature map size (H, W), before any downscaling
        num_heads (int): Number of attention heads used by every block
        window_size (Tuple[int, int]): Attention window size
        mlp_ratio (float): Ratio of the hidden dimension in the FFN to the input channels
        drop (float): Dropout in input mapping
        drop_attn (float): Dropout rate of attention map
        drop_path (Union[List[float], float]): Dropout in main path; a list supplies one rate per block
        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
        grad_checkpointing (bool): If true checkpointing is utilized
        extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks, 0 disables
        sequential_attn (bool): If true sequential self-attention is performed
    """

    def __init__(
            self,
            embed_dim: int,
            depth: int,
            downscale: bool,
            num_heads: int,
            feat_size: Tuple[int, int],
            window_size: Tuple[int, int],
            mlp_ratio: float = 4.0,
            drop: float = 0.0,
            drop_attn: float = 0.0,
            drop_path: Union[List[float], float] = 0.0,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            grad_checkpointing: bool = False,
            extra_norm_period: int = 0,
            sequential_attn: bool = False,
    ) -> None:
        super(SwinTransformerStage, self).__init__()
        self.downscale: bool = downscale
        self.grad_checkpointing: bool = grad_checkpointing

        # Patch merging halves the spatial size and doubles the channel count.
        if downscale:
            self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2)
            self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer)
            block_dim = embed_dim * 2
        else:
            self.feat_size: Tuple[int, int] = feat_size
            self.downsample = nn.Identity()
            block_dim = embed_dim

        stage_blocks = []
        for index in range(depth):
            # Even-indexed blocks use unshifted windows, odd-indexed ones shift by half a window.
            if index % 2 == 0:
                shift_size = tuple(0 for _ in window_size)
            else:
                shift_size = tuple(w // 2 for w in window_size)
            # A list provides an individual stochastic-depth rate per block.
            block_drop_path = drop_path[index] if isinstance(drop_path, list) else drop_path
            # Extra norm goes on every ``extra_norm_period``-th block (1-based); never when the period is 0.
            extra_norm = bool(extra_norm_period) and (index + 1) % extra_norm_period == 0
            stage_blocks.append(
                SwinTransformerBlock(
                    dim=block_dim,
                    num_heads=num_heads,
                    feat_size=self.feat_size,
                    window_size=window_size,
                    shift_size=shift_size,
                    mlp_ratio=mlp_ratio,
                    drop=drop,
                    drop_attn=drop_attn,
                    drop_path=block_drop_path,
                    extra_norm=extra_norm,
                    sequential_attn=sequential_attn,
                    norm_layer=norm_layer,
                )
            )
        self.blocks = nn.Sequential(*stage_blocks)

    def update_input_size(self, new_window_size: int, new_feat_size: Tuple[int, int]) -> None:
        """Store a new feature resolution and forward it, with the window size, to every block.

        Args:
            new_window_size (int): New window size
            new_feat_size (Tuple[int, int]): New input resolution (halved per axis when the stage downscales)
        """
        if self.downscale:
            new_feat_size = (new_feat_size[0] // 2, new_feat_size[1] // 2)
        self.feat_size: Tuple[int, int] = new_feat_size
        for block in self.blocks:
            block.update_input_size(new_window_size=new_window_size, new_feat_size=self.feat_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x (torch.Tensor): Input tensor; after the downsample step it is unpacked as [B, C, H, W]
        Returns:
            output (torch.Tensor): Block output reshaped to [B, H, W, -1] and converted back
                with ``bhwc_to_bchw``
        """
        x = self.downsample(x)
        batch, channels, height, width = x.shape

        # The blocks consume a flattened token sequence of shape [B, H * W, C].
        x = bchw_to_bhwc(x).reshape(batch, height * width, channels)
        for block in self.blocks:
            # Perform checkpointing if utilized
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint.checkpoint(block, x)
            else:
                x = block(x)
        return bhwc_to_bchw(x.reshape(batch, height, width, -1))
|
||||||
|
|
||||||
|
|
||||||
|
class SwinTransformerV2Cr(nn.Module):
    r"""Swin Transformer V2 classification model.

    A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` -
    https://arxiv.org/pdf/2111.09883

    Args:
        img_size (Tuple[int, int]): Input resolution. Default: (224, 224)
        patch_size (int | tuple(int)): Patch size. Default: 4
        window_size (Optional[int]): Window size. If None, img_size // img_window_ratio. Default: None
        img_window_ratio (int): Divisor used to derive the window size from the image size. Default: 32
        in_chans (int): Number of input channels. Default: 3
        num_classes (int): Number of output classes; values <= 0 give an identity head. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (Tuple[int, ...]): Number of blocks in each stage. Default: (2, 2, 6, 2)
        num_heads (Tuple[int, ...]): Number of attention heads in each stage. Default: (3, 6, 12, 24)
        mlp_ratio (float): Ratio of the hidden dimension in the FFN to the input channels. Default: 4.0
        drop_rate (float): Dropout rate. Default: 0.0
        attn_drop_rate (float): Dropout rate of attention map. Default: 0.0
        drop_path_rate (float): Stochastic depth rate, ramped linearly from 0 over all blocks. Default: 0.0
        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
        grad_checkpointing (bool): If true checkpointing is utilized. Default: False
        extra_norm_period (int): Forwarded to every stage as its extra-norm period. Default: 0
        sequential_attn (bool): If true sequential self-attention is performed. Default: False
        global_pool (str): 'avg' averages over the spatial dims before the head, any other
            value skips pooling. Default: 'avg'
        **kwargs: Accepted for call compatibility; not used by this constructor.
    """

    def __init__(
            self,
            img_size: Tuple[int, int] = (224, 224),
            patch_size: int = 4,
            window_size: Optional[int] = None,
            img_window_ratio: int = 32,
            in_chans: int = 3,
            num_classes: int = 1000,
            embed_dim: int = 96,
            depths: Tuple[int, ...] = (2, 2, 6, 2),
            num_heads: Tuple[int, ...] = (3, 6, 12, 24),
            mlp_ratio: float = 4.0,
            drop_rate: float = 0.0,
            attn_drop_rate: float = 0.0,
            drop_path_rate: float = 0.0,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            grad_checkpointing: bool = False,
            extra_norm_period: int = 0,
            sequential_attn: bool = False,
            global_pool: str = 'avg',
            **kwargs: Any
    ) -> None:
        super(SwinTransformerV2Cr, self).__init__()
        img_size = to_2tuple(img_size)
        if window_size is None:
            window_size = tuple([s // img_window_ratio for s in img_size])
        else:
            window_size = to_2tuple(window_size)

        self.num_classes: int = num_classes
        # Stored exactly as passed (int or tuple); update_input_size() normalizes it on use.
        self.patch_size = patch_size
        self.img_size: Tuple[int, int] = img_size
        self.window_size: Tuple[int, int] = window_size
        # Channels double at every stage after the first.
        self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))

        self.patch_embed: nn.Module = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            embed_dim=embed_dim, norm_layer=norm_layer)
        patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size

        # One stochastic-depth rate per block, increasing linearly over the whole network.
        drop_path_rates = torch.linspace(0.0, drop_path_rate, sum(depths)).tolist()
        stages = []
        block_start = 0
        for index, (depth, stage_heads) in enumerate(zip(depths, num_heads)):
            # Stages 0 and 1 receive the patch-grid resolution; each later stage receives
            # the previous stage's (halved) output. The stage itself applies its own downscale.
            stage_scale = 2 ** max(index - 1, 0)
            stages.append(
                SwinTransformerStage(
                    embed_dim=embed_dim * stage_scale,
                    depth=depth,
                    downscale=index != 0,
                    feat_size=(patch_grid_size[0] // stage_scale, patch_grid_size[1] // stage_scale),
                    num_heads=stage_heads,
                    window_size=window_size,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_attn=attn_drop_rate,
                    drop_path=drop_path_rates[block_start:block_start + depth],
                    grad_checkpointing=grad_checkpointing,
                    extra_norm_period=extra_norm_period,
                    sequential_attn=sequential_attn,
                    norm_layer=norm_layer,
                )
            )
            block_start += depth
        self.stages = nn.Sequential(*stages)

        self.global_pool: str = global_pool
        self.head: nn.Module = nn.Linear(
            in_features=self.num_features, out_features=num_classes) if num_classes > 0 else nn.Identity()

        # FIXME weight init TBD, PyTorch default init appears to be working well,
        #  but differs from usual ViT or Swin init.
        # named_apply(init_weights, self)

    def update_input_size(
            self,
            new_img_size: Optional[Tuple[int, int]] = None,
            new_window_size: Optional[int] = None,
            img_window_ratio: int = 32,
    ) -> None:
        """Method updates the image resolution to be processed and window size and so the pair-wise relative positions.

        Args:
            new_img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
            new_window_size (Optional[int]): New window size, if None based on new_img_size // img_window_ratio
            img_window_ratio (int): divisor for calculating window size from image size
        """
        # Check parameters
        if new_img_size is None:
            new_img_size = self.img_size
        else:
            new_img_size = to_2tuple(new_img_size)
        if new_window_size is None:
            new_window_size = tuple([s // img_window_ratio for s in new_img_size])
        else:
            # Normalize an explicit window size the same way __init__ does.
            new_window_size = to_2tuple(new_window_size)
        # patch_size is stored as given to __init__, which may be an int or a pair.
        patch_size = self.patch_size
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        # Compute new patch resolution & update resolution of each stage
        new_patch_grid_size = (new_img_size[0] // patch_size[0], new_img_size[1] // patch_size[1])
        for index, stage in enumerate(self.stages):
            stage_scale = 2 ** max(index - 1, 0)
            # The stage method's parameter is `new_feat_size`; the previous keyword
            # `new_img_size` made every call raise TypeError.
            stage.update_input_size(
                new_window_size=new_window_size,
                new_feat_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
            )
        # Record the new configuration so a later call with new_img_size=None starts from it.
        self.img_size = new_img_size
        self.window_size = new_window_size
        # NOTE(review): self.patch_embed keeps its own img_size and asserts on it in forward();
        # confirm whether it must be updated as well before feeding inputs of the new size.

    def get_classifier(self) -> nn.Module:
        """Method returns the classification head of the model.

        Returns:
            head (nn.Module): Current classification head
        """
        head: nn.Module = self.head
        return head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
        """Method resets the classification head.

        Args:
            num_classes (int): Number of classes to be predicted; values <= 0 give an identity head
            global_pool (Optional[str]): New pooling mode; the current mode is kept when None
        """
        self.num_classes: int = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head: nn.Module = nn.Linear(
            in_features=self.num_features, out_features=num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the patch embedding followed by all stages and return the resulting feature map."""
        x = self.patch_embed(x)
        x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the feature map (when global_pool == 'avg') and apply the head unless pre_logits is set."""
        if self.global_pool == 'avg':
            # Average over dims 2 and 3 of the feature map.
            x = x.mean(dim=(2, 3))
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Full forward pass: features followed by the classification head."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
def init_weights(module: nn.Module, name: str = ''):
    """Re-initialize an ``nn.Linear`` layer in place; every other module type is left untouched.

    Args:
        module (nn.Module): Module to initialize
        name (str): Name of the module; a name containing 'qkv' selects the fused-QKV scheme
    """
    # FIXME WIP
    if not isinstance(module, nn.Linear):
        return

    weight = module.weight
    if 'qkv' in name:
        # treat the weights of Q, K, V separately: the uniform bound uses one third
        # of the weight's first dimension rather than all of it
        rows_per_proj = weight.shape[0] // 3
        bound = math.sqrt(6. / float(rows_per_proj + weight.shape[1]))
        nn.init.uniform_(weight, -bound, bound)
    else:
        nn.init.xavier_uniform_(weight)

    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_swin_transformer_v2_cr(variant, pretrained=False, default_cfg=None, **kwargs):
    """Build a ``SwinTransformerV2Cr`` for ``variant`` through ``build_model_with_cfg``.

    When ``default_cfg`` is not given, a deep copy of ``default_cfgs[variant]`` is used.
    ``num_classes`` and ``img_size`` are taken from ``kwargs`` when present, otherwise
    from the config ('num_classes' and the last two entries of 'input_size').

    Raises:
        RuntimeError: If a truthy ``features_only`` is passed in ``kwargs``.
    """
    if default_cfg is None:
        # Copy so that the overlay below is not applied to the shared registry entry.
        default_cfg = deepcopy(default_cfgs[variant])
    overlay_external_default_cfg(default_cfg, kwargs)
    cfg_num_classes = default_cfg['num_classes']
    cfg_img_size = default_cfg['input_size'][-2:]

    # Explicit caller values win over the config defaults.
    num_classes = kwargs.pop('num_classes', cfg_num_classes)
    img_size = kwargs.pop('img_size', cfg_img_size)
    if kwargs.get('features_only'):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    return build_model_with_cfg(
        SwinTransformerV2Cr,
        variant,
        pretrained,
        default_cfg=default_cfg,
        img_size=img_size,
        num_classes=num_classes,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs
    )
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def swin_v2_cr_tiny_384(pretrained=False, **kwargs):
    """Swin-T V2 CR @ 384x384 (ImageNet-1k).

    Fixes embed_dim=96, depths=(2, 2, 6, 2) and num_heads=(3, 6, 12, 24); all other
    keyword arguments are forwarded unchanged to the model factory.
    """
    variant_args = dict(embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)
    return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_384', pretrained=pretrained, **variant_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def swin_v2_cr_tiny_224(pretrained=False, **kwargs):
    """Swin-T V2 CR @ 224x224 (ImageNet-1k).

    Fixes embed_dim=96, depths=(2, 2, 6, 2) and num_heads=(3, 6, 12, 24); all other
    keyword arguments are forwarded unchanged to the model factory.
    """
    variant_args = dict(embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)
    return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **variant_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def swin_v2_cr_small_384(pretrained=False, **kwargs):
    """Swin-S V2 CR @ 384x384 (ImageNet-1k).

    Fixes embed_dim=96, depths=(2, 2, 18, 2) and num_heads=(3, 6, 12, 24); all other
    keyword arguments are forwarded unchanged to the model factory.
    """
    variant_args = dict(embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
    return _create_swin_transformer_v2_cr('swin_v2_cr_small_384', pretrained=pretrained, **variant_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def swin_v2_cr_small_224(pretrained=False, **kwargs):
    """Swin-S V2 CR @ 224x224 (ImageNet-1k).

    Fixes embed_dim=96, depths=(2, 2, 18, 2) and num_heads=(3, 6, 12, 24); all other
    keyword arguments are forwarded unchanged to the model factory.
    """
    variant_args = dict(embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
    return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **variant_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def swin_v2_cr_base_384(pretrained=False, **kwargs):
    """Swin-B V2 CR @ 384x384 (ImageNet-1k).

    Fixes embed_dim=128, depths=(2, 2, 18, 2) and num_heads=(4, 8, 16, 32); all other
    keyword arguments are forwarded unchanged to the model factory.
    """
    variant_args = dict(embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
    return _create_swin_transformer_v2_cr('swin_v2_cr_base_384', pretrained=pretrained, **variant_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def swin_v2_cr_base_224(pretrained=False, **kwargs):
    """Swin-B V2 CR @ 224x224 (ImageNet-1k).

    Fixes embed_dim=128, depths=(2, 2, 18, 2) and num_heads=(4, 8, 16, 32); all other
    keyword arguments are forwarded unchanged to the model factory.
    """
    variant_args = dict(embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
    return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **variant_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def swin_v2_cr_large_384(pretrained=False, **kwargs):
    """Swin-L V2 CR @ 384x384 (ImageNet-1k).

    Fixes embed_dim=192, depths=(2, 2, 18, 2) and num_heads=(6, 12, 24, 48); all other
    keyword arguments are forwarded unchanged to the model factory.
    """
    variant_args = dict(embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
    return _create_swin_transformer_v2_cr('swin_v2_cr_large_384', pretrained=pretrained, **variant_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def swin_v2_cr_large_224(pretrained=False, **kwargs):
    """Swin-L V2 CR @ 224x224 (ImageNet-1k).

    Fixes embed_dim=192, depths=(2, 2, 18, 2) and num_heads=(6, 12, 24, 48); all other
    keyword arguments are forwarded unchanged to the model factory.
    """
    variant_args = dict(embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
    return _create_swin_transformer_v2_cr('swin_v2_cr_large_224', pretrained=pretrained, **variant_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def swin_v2_cr_huge_384(pretrained=False, **kwargs):
    """Swin-H V2 CR @ 384x384 (ImageNet-1k).

    Fixes embed_dim=352, depths=(2, 2, 18, 2), num_heads=(11, 22, 44, 88) and
    extra_norm_period=6; all other keyword arguments are forwarded unchanged to
    the model factory.
    """
    # head count not certain for Huge, 384 & 224 trying diff values
    huge_heads = (11, 22, 44, 88)
    variant_args = dict(embed_dim=352, depths=(2, 2, 18, 2), num_heads=huge_heads, extra_norm_period=6, **kwargs)
    return _create_swin_transformer_v2_cr('swin_v2_cr_huge_384', pretrained=pretrained, **variant_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def swin_v2_cr_huge_224(pretrained=False, **kwargs):
    """Swin-H V2 CR @ 224x224 (ImageNet-1k).

    Fixes embed_dim=352, depths=(2, 2, 18, 2), num_heads=(8, 16, 32, 64) and
    extra_norm_period=6; all other keyword arguments are forwarded unchanged to
    the model factory.
    """
    # head count not certain for Huge, 384 & 224 trying diff values
    huge_heads = (8, 16, 32, 64)
    variant_args = dict(embed_dim=352, depths=(2, 2, 18, 2), num_heads=huge_heads, extra_norm_period=6, **kwargs)
    return _create_swin_transformer_v2_cr('swin_v2_cr_huge_224', pretrained=pretrained, **variant_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def swin_v2_cr_giant_384(pretrained=False, **kwargs):
    """Swin-G V2 CR @ 384x384 (ImageNet-1k).

    Fixes embed_dim=512, depths=(2, 2, 42, 2), num_heads=(16, 32, 64, 128) and
    extra_norm_period=6; all other keyword arguments are forwarded unchanged to
    the model factory.
    """
    variant_args = dict(
        embed_dim=512, depths=(2, 2, 42, 2), num_heads=(16, 32, 64, 128), extra_norm_period=6, **kwargs)
    return _create_swin_transformer_v2_cr('swin_v2_cr_giant_384', pretrained=pretrained, **variant_args)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
def swin_v2_cr_giant_224(pretrained=False, **kwargs):
    """Swin-G V2 CR @ 224x224 (ImageNet-1k).

    Fixes embed_dim=512, depths=(2, 2, 42, 2), num_heads=(16, 32, 64, 128) and
    extra_norm_period=6; all other keyword arguments are forwarded unchanged to
    the model factory.
    """
    variant_args = dict(
        embed_dim=512, depths=(2, 2, 42, 2), num_heads=(16, 32, 64, 128), extra_norm_period=6, **kwargs)
    return _create_swin_transformer_v2_cr('swin_v2_cr_giant_224', pretrained=pretrained, **variant_args)
|
Loading…
Reference in new issue