You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
651 lines
27 KiB
651 lines
27 KiB
""" Swin Transformer
|
|
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
|
|
- https://arxiv.org/pdf/2103.14030
|
|
|
|
Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
|
|
|
|
"""
|
|
# --------------------------------------------------------
|
|
# Swin Transformer
|
|
# Copyright (c) 2021 Microsoft
|
|
# Licensed under The MIT License [see LICENSE for details]
|
|
# Written by Ze Liu
|
|
# --------------------------------------------------------
|
|
import logging
|
|
import math
|
|
from copy import deepcopy
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.utils.checkpoint as checkpoint
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
|
from .layers import DropPath, to_2tuple, trunc_normal_
|
|
from .registry import register_model
|
|
from .vision_transformer import checkpoint_filter_fn, Mlp, PatchEmbed, _init_vit_weights
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
return {
|
|
'url': url,
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
|
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
'first_conv': 'patch_embed.proj', 'classifier': 'head',
|
|
**kwargs
|
|
}
|
|
|
|
|
|
default_cfgs = {
|
|
# patch models (my experiments)
|
|
'swin_base_patch4_window12_384': _cfg(
|
|
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
'swin_base_patch4_window7_224': _cfg(
|
|
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
|
|
),
|
|
|
|
'swin_large_patch4_window12_384': _cfg(
|
|
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
'swin_large_patch4_window7_224': _cfg(
|
|
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
|
|
),
|
|
|
|
'swin_small_patch4_window7_224': _cfg(
|
|
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
|
|
),
|
|
|
|
'swin_tiny_patch4_window7_224': _cfg(
|
|
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
|
|
),
|
|
|
|
'swin_base_patch4_window12_384_in22k': _cfg(
|
|
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
|
|
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
|
|
|
|
'swin_base_patch4_window7_224_in22k': _cfg(
|
|
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
|
|
num_classes=21841),
|
|
|
|
'swin_large_patch4_window12_384_in22k': _cfg(
|
|
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
|
|
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
|
|
|
|
'swin_large_patch4_window7_224_in22k': _cfg(
|
|
url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
|
|
num_classes=21841),
|
|
|
|
}
|
|
|
|
|
|
def window_partition(x, window_size: int):
|
|
"""
|
|
Args:
|
|
x: (B, H, W, C)
|
|
window_size (int): window size
|
|
|
|
Returns:
|
|
windows: (num_windows*B, window_size, window_size, C)
|
|
"""
|
|
B, H, W, C = x.shape
|
|
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
|
return windows
|
|
|
|
|
|
def window_reverse(windows, window_size: int, H: int, W: int):
|
|
"""
|
|
Args:
|
|
windows: (num_windows*B, window_size, window_size, C)
|
|
window_size (int): Window size
|
|
H (int): Height of image
|
|
W (int): Width of image
|
|
|
|
Returns:
|
|
x: (B, H, W, C)
|
|
"""
|
|
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
|
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
|
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
|
return x
|
|
|
|
|
|
class WindowAttention(nn.Module):
|
|
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
|
It supports both of shifted and non-shifted window.
|
|
|
|
Args:
|
|
dim (int): Number of input channels.
|
|
window_size (tuple[int]): The height and width of the window.
|
|
num_heads (int): Number of attention heads.
|
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
|
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
|
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
|
"""
|
|
|
|
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
|
|
|
super().__init__()
|
|
self.dim = dim
|
|
self.window_size = window_size # Wh, Ww
|
|
self.num_heads = num_heads
|
|
head_dim = dim // num_heads
|
|
self.scale = qk_scale or head_dim ** -0.5
|
|
|
|
# define a parameter table of relative position bias
|
|
self.relative_position_bias_table = nn.Parameter(
|
|
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
|
|
|
# get pair-wise relative position index for each token inside the window
|
|
coords_h = torch.arange(self.window_size[0])
|
|
coords_w = torch.arange(self.window_size[1])
|
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
|
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
|
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
|
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
|
relative_coords[:, :, 1] += self.window_size[1] - 1
|
|
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
|
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
|
self.register_buffer("relative_position_index", relative_position_index)
|
|
|
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
|
self.attn_drop = nn.Dropout(attn_drop)
|
|
self.proj = nn.Linear(dim, dim)
|
|
self.proj_drop = nn.Dropout(proj_drop)
|
|
|
|
trunc_normal_(self.relative_position_bias_table, std=.02)
|
|
self.softmax = nn.Softmax(dim=-1)
|
|
|
|
def forward(self, x, mask: Optional[torch.Tensor] = None):
|
|
"""
|
|
Args:
|
|
x: input features with shape of (num_windows*B, N, C)
|
|
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
|
"""
|
|
B_, N, C = x.shape
|
|
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
|
|
|
q = q * self.scale
|
|
attn = (q @ k.transpose(-2, -1))
|
|
|
|
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
|
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
|
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
|
attn = attn + relative_position_bias.unsqueeze(0)
|
|
|
|
if mask is not None:
|
|
nW = mask.shape[0]
|
|
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
|
attn = attn.view(-1, self.num_heads, N, N)
|
|
attn = self.softmax(attn)
|
|
else:
|
|
attn = self.softmax(attn)
|
|
|
|
attn = self.attn_drop(attn)
|
|
|
|
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
|
x = self.proj(x)
|
|
x = self.proj_drop(x)
|
|
return x
|
|
|
|
|
|
class SwinTransformerBlock(nn.Module):
|
|
r""" Swin Transformer Block.
|
|
|
|
Args:
|
|
dim (int): Number of input channels.
|
|
input_resolution (tuple[int]): Input resulotion.
|
|
num_heads (int): Number of attention heads.
|
|
window_size (int): Window size.
|
|
shift_size (int): Shift size for SW-MSA.
|
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
|
drop (float, optional): Dropout rate. Default: 0.0
|
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
|
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
|
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
|
"""
|
|
|
|
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
|
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
|
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
|
super().__init__()
|
|
self.dim = dim
|
|
self.input_resolution = input_resolution
|
|
self.num_heads = num_heads
|
|
self.window_size = window_size
|
|
self.shift_size = shift_size
|
|
self.mlp_ratio = mlp_ratio
|
|
if min(self.input_resolution) <= self.window_size:
|
|
# if window size is larger than input resolution, we don't partition windows
|
|
self.shift_size = 0
|
|
self.window_size = min(self.input_resolution)
|
|
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
self.attn = WindowAttention(
|
|
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
|
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
self.norm2 = norm_layer(dim)
|
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
|
|
|
if self.shift_size > 0:
|
|
# calculate attention mask for SW-MSA
|
|
H, W = self.input_resolution
|
|
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
|
h_slices = (slice(0, -self.window_size),
|
|
slice(-self.window_size, -self.shift_size),
|
|
slice(-self.shift_size, None))
|
|
w_slices = (slice(0, -self.window_size),
|
|
slice(-self.window_size, -self.shift_size),
|
|
slice(-self.shift_size, None))
|
|
cnt = 0
|
|
for h in h_slices:
|
|
for w in w_slices:
|
|
img_mask[:, h, w, :] = cnt
|
|
cnt += 1
|
|
|
|
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
|
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
|
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
|
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
|
else:
|
|
attn_mask = None
|
|
|
|
self.register_buffer("attn_mask", attn_mask)
|
|
|
|
def forward(self, x):
|
|
H, W = self.input_resolution
|
|
B, L, C = x.shape
|
|
assert L == H * W, "input feature has wrong size"
|
|
|
|
shortcut = x
|
|
x = self.norm1(x)
|
|
x = x.view(B, H, W, C)
|
|
|
|
# cyclic shift
|
|
if self.shift_size > 0:
|
|
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
|
else:
|
|
shifted_x = x
|
|
|
|
# partition windows
|
|
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
|
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
|
|
|
# W-MSA/SW-MSA
|
|
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
|
|
|
# merge windows
|
|
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
|
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
|
|
|
# reverse cyclic shift
|
|
if self.shift_size > 0:
|
|
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
|
else:
|
|
x = shifted_x
|
|
x = x.view(B, H * W, C)
|
|
|
|
# FFN
|
|
x = shortcut + self.drop_path(x)
|
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
|
|
|
return x
|
|
|
|
|
|
class PatchMerging(nn.Module):
|
|
r""" Patch Merging Layer.
|
|
|
|
Args:
|
|
input_resolution (tuple[int]): Resolution of input feature.
|
|
dim (int): Number of input channels.
|
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
|
"""
|
|
|
|
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
|
super().__init__()
|
|
self.input_resolution = input_resolution
|
|
self.dim = dim
|
|
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
|
self.norm = norm_layer(4 * dim)
|
|
|
|
def forward(self, x):
|
|
"""
|
|
x: B, H*W, C
|
|
"""
|
|
H, W = self.input_resolution
|
|
B, L, C = x.shape
|
|
assert L == H * W, "input feature has wrong size"
|
|
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
|
|
|
x = x.view(B, H, W, C)
|
|
|
|
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
|
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
|
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
|
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
|
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
|
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
|
|
|
x = self.norm(x)
|
|
x = self.reduction(x)
|
|
|
|
return x
|
|
|
|
def extra_repr(self) -> str:
|
|
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
|
|
|
def flops(self):
|
|
H, W = self.input_resolution
|
|
flops = H * W * self.dim
|
|
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
|
return flops
|
|
|
|
|
|
class BasicLayer(nn.Module):
|
|
""" A basic Swin Transformer layer for one stage.
|
|
|
|
Args:
|
|
dim (int): Number of input channels.
|
|
input_resolution (tuple[int]): Input resolution.
|
|
depth (int): Number of blocks.
|
|
num_heads (int): Number of attention heads.
|
|
window_size (int): Local window size.
|
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
|
drop (float, optional): Dropout rate. Default: 0.0
|
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
|
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
|
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
|
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
|
"""
|
|
|
|
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
|
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
|
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
|
|
|
super().__init__()
|
|
self.dim = dim
|
|
self.input_resolution = input_resolution
|
|
self.depth = depth
|
|
self.use_checkpoint = use_checkpoint
|
|
|
|
# build blocks
|
|
self.blocks = nn.ModuleList([
|
|
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
|
num_heads=num_heads, window_size=window_size,
|
|
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
|
mlp_ratio=mlp_ratio,
|
|
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
drop=drop, attn_drop=attn_drop,
|
|
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
|
norm_layer=norm_layer)
|
|
for i in range(depth)])
|
|
|
|
# patch merging layer
|
|
if downsample is not None:
|
|
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
|
else:
|
|
self.downsample = None
|
|
|
|
def forward(self, x):
|
|
for blk in self.blocks:
|
|
if not torch.jit.is_scripting() and self.use_checkpoint:
|
|
x = checkpoint.checkpoint(blk, x)
|
|
else:
|
|
x = blk(x)
|
|
if self.downsample is not None:
|
|
x = self.downsample(x)
|
|
return x
|
|
|
|
def extra_repr(self) -> str:
|
|
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
|
|
|
|
|
class SwinTransformer(nn.Module):
|
|
r""" Swin Transformer
|
|
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
|
|
https://arxiv.org/pdf/2103.14030
|
|
|
|
Args:
|
|
img_size (int | tuple(int)): Input image size. Default 224
|
|
patch_size (int | tuple(int)): Patch size. Default: 4
|
|
in_chans (int): Number of input image channels. Default: 3
|
|
num_classes (int): Number of classes for classification head. Default: 1000
|
|
embed_dim (int): Patch embedding dimension. Default: 96
|
|
depths (tuple(int)): Depth of each Swin Transformer layer.
|
|
num_heads (tuple(int)): Number of attention heads in different layers.
|
|
window_size (int): Window size. Default: 7
|
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
|
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
|
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
|
drop_rate (float): Dropout rate. Default: 0
|
|
attn_drop_rate (float): Attention dropout rate. Default: 0
|
|
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
|
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
|
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
|
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
|
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
|
"""
|
|
|
|
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
|
|
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
|
|
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
|
use_checkpoint=False, weight_init='', **kwargs):
|
|
super().__init__()
|
|
|
|
self.num_classes = num_classes
|
|
self.num_layers = len(depths)
|
|
self.embed_dim = embed_dim
|
|
self.ape = ape
|
|
self.patch_norm = patch_norm
|
|
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
|
|
self.mlp_ratio = mlp_ratio
|
|
|
|
# split image into non-overlapping patches
|
|
self.patch_embed = PatchEmbed(
|
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
|
norm_layer=norm_layer if self.patch_norm else None)
|
|
num_patches = self.patch_embed.num_patches
|
|
self.patch_grid = self.patch_embed.patch_grid
|
|
|
|
# absolute position embedding
|
|
if self.ape:
|
|
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
|
trunc_normal_(self.absolute_pos_embed, std=.02)
|
|
else:
|
|
self.absolute_pos_embed = None
|
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
|
|
|
# stochastic depth
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
|
|
|
# build layers
|
|
layers = []
|
|
for i_layer in range(self.num_layers):
|
|
layers += [BasicLayer(
|
|
dim=int(embed_dim * 2 ** i_layer),
|
|
input_resolution=(self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)),
|
|
depth=depths[i_layer],
|
|
num_heads=num_heads[i_layer],
|
|
window_size=window_size,
|
|
mlp_ratio=self.mlp_ratio,
|
|
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
drop=drop_rate, attn_drop=attn_drop_rate,
|
|
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
|
norm_layer=norm_layer,
|
|
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
|
use_checkpoint=use_checkpoint)
|
|
]
|
|
self.layers = nn.Sequential(*layers)
|
|
|
|
self.norm = norm_layer(self.num_features)
|
|
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
|
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
|
|
head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
|
|
if weight_init.startswith('jax'):
|
|
for n, m in self.named_modules():
|
|
_init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
|
|
else:
|
|
self.apply(_init_vit_weights)
|
|
|
|
@torch.jit.ignore
|
|
def no_weight_decay(self):
|
|
return {'absolute_pos_embed'}
|
|
|
|
@torch.jit.ignore
|
|
def no_weight_decay_keywords(self):
|
|
return {'relative_position_bias_table'}
|
|
|
|
def forward_features(self, x):
|
|
x = self.patch_embed(x)
|
|
if self.absolute_pos_embed is not None:
|
|
x = x + self.absolute_pos_embed
|
|
x = self.pos_drop(x)
|
|
x = self.layers(x)
|
|
x = self.norm(x) # B L C
|
|
x = self.avgpool(x.transpose(1, 2)) # B C 1
|
|
x = torch.flatten(x, 1)
|
|
return x
|
|
|
|
def forward(self, x):
|
|
x = self.forward_features(x)
|
|
x = self.head(x)
|
|
return x
|
|
|
|
|
|
def _create_swin_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
|
|
if default_cfg is None:
|
|
default_cfg = deepcopy(default_cfgs[variant])
|
|
overlay_external_default_cfg(default_cfg, kwargs)
|
|
default_num_classes = default_cfg['num_classes']
|
|
default_img_size = default_cfg['input_size'][-2:]
|
|
|
|
num_classes = kwargs.pop('num_classes', default_num_classes)
|
|
img_size = kwargs.pop('img_size', default_img_size)
|
|
if kwargs.get('features_only', None):
|
|
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
|
|
|
model = build_model_with_cfg(
|
|
SwinTransformer, variant, pretrained,
|
|
default_cfg=default_cfg,
|
|
img_size=img_size,
|
|
num_classes=num_classes,
|
|
pretrained_filter_fn=checkpoint_filter_fn,
|
|
**kwargs)
|
|
|
|
return model
|
|
|
|
|
|
|
|
@register_model
|
|
def swin_base_patch4_window12_384(pretrained=False, **kwargs):
|
|
""" Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k
|
|
"""
|
|
model_kwargs = dict(
|
|
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
|
return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
@register_model
|
|
def swin_base_patch4_window7_224(pretrained=False, **kwargs):
|
|
""" Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k
|
|
"""
|
|
model_kwargs = dict(
|
|
patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
|
return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
@register_model
|
|
def swin_large_patch4_window12_384(pretrained=False, **kwargs):
|
|
""" Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k
|
|
"""
|
|
model_kwargs = dict(
|
|
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
|
return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
@register_model
|
|
def swin_large_patch4_window7_224(pretrained=False, **kwargs):
|
|
""" Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k
|
|
"""
|
|
model_kwargs = dict(
|
|
patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
|
return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
@register_model
|
|
def swin_small_patch4_window7_224(pretrained=False, **kwargs):
|
|
""" Swin-S @ 224x224, trained ImageNet-1k
|
|
"""
|
|
model_kwargs = dict(
|
|
patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
|
|
return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
@register_model
|
|
def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
|
|
""" Swin-T @ 224x224, trained ImageNet-1k
|
|
"""
|
|
model_kwargs = dict(
|
|
patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)
|
|
return _create_swin_transformer('swin_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
@register_model
|
|
def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):
|
|
""" Swin-B @ 384x384, trained ImageNet-22k
|
|
"""
|
|
model_kwargs = dict(
|
|
patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
|
return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
@register_model
|
|
def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):
|
|
""" Swin-B @ 224x224, trained ImageNet-22k
|
|
"""
|
|
model_kwargs = dict(
|
|
patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
|
|
return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
@register_model
|
|
def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):
|
|
""" Swin-L @ 384x384, trained ImageNet-22k
|
|
"""
|
|
model_kwargs = dict(
|
|
patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
|
return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
@register_model
|
|
def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
|
|
""" Swin-L @ 224x224, trained ImageNet-22k
|
|
"""
|
|
model_kwargs = dict(
|
|
patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
|
return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs) |