|
|
|
@ -11,9 +11,10 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
|
|
|
|
|
# Copyright (c) 2022 Mingyu Ding
|
|
|
|
|
# All rights reserved.
|
|
|
|
|
# This source code is licensed under the MIT license
|
|
|
|
|
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
import itertools
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Tuple
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
@ -21,9 +22,8 @@ import torch.nn.functional as F
|
|
|
|
|
from torch import Tensor
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp # ClassifierHead
|
|
|
|
|
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
|
|
|
|
|
from ._builder import build_model_with_cfg
|
|
|
|
|
from ._features import FeatureInfo
|
|
|
|
|
from ._features_fx import register_notrace_function
|
|
|
|
|
from ._manipulate import checkpoint_seq
|
|
|
|
|
from ._pretrained import generate_default_cfgs
|
|
|
|
@ -33,89 +33,83 @@ __all__ = ['DaViT']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvPosEnc(nn.Module):
|
|
|
|
|
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim: int, k: int = 3, act: bool = False):
|
|
|
|
|
super(ConvPosEnc, self).__init__()
|
|
|
|
|
self.proj = nn.Conv2d(dim,
|
|
|
|
|
dim,
|
|
|
|
|
to_2tuple(k),
|
|
|
|
|
to_2tuple(1),
|
|
|
|
|
to_2tuple(k // 2),
|
|
|
|
|
groups=dim)
|
|
|
|
|
self.normtype = normtype
|
|
|
|
|
self.norm = nn.Identity()
|
|
|
|
|
if self.normtype == 'batch':
|
|
|
|
|
self.norm = nn.BatchNorm2d(dim)
|
|
|
|
|
elif self.normtype == 'layer':
|
|
|
|
|
self.norm = nn.LayerNorm(dim)
|
|
|
|
|
self.activation = nn.GELU() if act else nn.Identity()
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor):
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
|
|
|
|
|
self.act = nn.GELU() if act else nn.Identity()
|
|
|
|
|
|
|
|
|
|
#feat = x.transpose(1, 2).view(B, C, H, W)
|
|
|
|
|
def forward(self, x: Tensor):
|
|
|
|
|
feat = self.proj(x)
|
|
|
|
|
if self.normtype == 'batch':
|
|
|
|
|
feat = self.norm(feat).flatten(2).transpose(1, 2)
|
|
|
|
|
elif self.normtype == 'layer':
|
|
|
|
|
feat = self.norm(feat.flatten(2).transpose(1, 2))
|
|
|
|
|
else:
|
|
|
|
|
feat = feat.flatten(2).transpose(1, 2)
|
|
|
|
|
x = x + self.activation(feat).transpose(1, 2).view(B, C, H, W)
|
|
|
|
|
x = x + self.act(feat)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PatchEmbed(nn.Module):
|
|
|
|
|
class Stem(nn.Module):
|
|
|
|
|
""" Size-agnostic implementation of 2D image to patch embedding,
|
|
|
|
|
allowing input size to be adjusted during model forward operation
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
patch_size=4,
|
|
|
|
|
in_chans=3,
|
|
|
|
|
embed_dim=96,
|
|
|
|
|
overlapped=False):
|
|
|
|
|
in_chs=3,
|
|
|
|
|
out_chs=96,
|
|
|
|
|
stride=4,
|
|
|
|
|
norm_layer=LayerNorm2d,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
patch_size = to_2tuple(patch_size)
|
|
|
|
|
self.patch_size = patch_size
|
|
|
|
|
self.in_chans = in_chans
|
|
|
|
|
self.embed_dim = embed_dim
|
|
|
|
|
|
|
|
|
|
if patch_size[0] == 4:
|
|
|
|
|
self.proj = nn.Conv2d(
|
|
|
|
|
in_chans,
|
|
|
|
|
embed_dim,
|
|
|
|
|
kernel_size=(7, 7),
|
|
|
|
|
stride=patch_size,
|
|
|
|
|
padding=(3, 3))
|
|
|
|
|
self.norm = nn.LayerNorm(embed_dim)
|
|
|
|
|
if patch_size[0] == 2:
|
|
|
|
|
kernel = 3 if overlapped else 2
|
|
|
|
|
pad = 1 if overlapped else 0
|
|
|
|
|
self.proj = nn.Conv2d(
|
|
|
|
|
in_chans,
|
|
|
|
|
embed_dim,
|
|
|
|
|
kernel_size=to_2tuple(kernel),
|
|
|
|
|
stride=patch_size,
|
|
|
|
|
padding=to_2tuple(pad))
|
|
|
|
|
self.norm = nn.LayerNorm(in_chans)
|
|
|
|
|
|
|
|
|
|
stride = to_2tuple(stride)
|
|
|
|
|
self.stride = stride
|
|
|
|
|
self.in_chs = in_chs
|
|
|
|
|
self.out_chs = out_chs
|
|
|
|
|
assert stride[0] == 4 # only setup for stride==4
|
|
|
|
|
self.conv = nn.Conv2d(
|
|
|
|
|
in_chs,
|
|
|
|
|
out_chs,
|
|
|
|
|
kernel_size=7,
|
|
|
|
|
stride=stride,
|
|
|
|
|
padding=3,
|
|
|
|
|
)
|
|
|
|
|
self.norm = norm_layer(out_chs)
|
|
|
|
|
|
|
|
|
|
def forward(self, x: Tensor):
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
if self.norm.normalized_shape[0] == self.in_chans:
|
|
|
|
|
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
|
|
x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
|
|
|
|
|
x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
|
|
|
|
|
x = self.conv(x)
|
|
|
|
|
x = self.norm(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
x = F.pad(x, (0, (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]))
|
|
|
|
|
x = F.pad(x, (0, 0, 0, (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]))
|
|
|
|
|
|
|
|
|
|
x = self.proj(x)
|
|
|
|
|
class Downsample(nn.Module):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
in_chs,
|
|
|
|
|
out_chs,
|
|
|
|
|
norm_layer=LayerNorm2d,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.in_chs = in_chs
|
|
|
|
|
self.out_chs = out_chs
|
|
|
|
|
|
|
|
|
|
if self.norm.normalized_shape[0] == self.embed_dim:
|
|
|
|
|
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
|
|
self.norm = norm_layer(in_chs)
|
|
|
|
|
self.conv = nn.Conv2d(
|
|
|
|
|
in_chs,
|
|
|
|
|
out_chs,
|
|
|
|
|
kernel_size=2,
|
|
|
|
|
stride=2,
|
|
|
|
|
padding=0,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def forward(self, x: Tensor):
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
x = self.norm(x)
|
|
|
|
|
x = F.pad(x, (0, (2 - W % 2) % 2))
|
|
|
|
|
x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
|
|
|
|
|
x = self.conv(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChannelAttention(nn.Module):
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, num_heads=8, qkv_bias=False):
|
|
|
|
@ -131,7 +125,7 @@ class ChannelAttention(nn.Module):
|
|
|
|
|
B, N, C = x.shape
|
|
|
|
|
|
|
|
|
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
|
|
|
q, k, v = qkv[0], qkv[1], qkv[2]
|
|
|
|
|
q, k, v = qkv.unbind(0)
|
|
|
|
|
|
|
|
|
|
k = k * self.scale
|
|
|
|
|
attention = k.transpose(-1, -2) @ v
|
|
|
|
@ -144,46 +138,60 @@ class ChannelAttention(nn.Module):
|
|
|
|
|
|
|
|
|
|
class ChannelBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
|
|
|
|
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
|
|
|
|
ffn=True, cpe_act=False):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
dim,
|
|
|
|
|
num_heads,
|
|
|
|
|
mlp_ratio=4.,
|
|
|
|
|
qkv_bias=False,
|
|
|
|
|
drop_path=0.,
|
|
|
|
|
act_layer=nn.GELU,
|
|
|
|
|
norm_layer=nn.LayerNorm,
|
|
|
|
|
ffn=True,
|
|
|
|
|
cpe_act=False,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
|
|
|
|
|
self.ffn = ffn
|
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
|
self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
|
|
|
|
|
|
|
|
|
|
if self.ffn:
|
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
|
|
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
|
|
|
self.mlp = Mlp(
|
|
|
|
|
in_features=dim,
|
|
|
|
|
hidden_features=mlp_hidden_dim,
|
|
|
|
|
act_layer=act_layer)
|
|
|
|
|
|
|
|
|
|
hidden_features=int(dim * mlp_ratio),
|
|
|
|
|
act_layer=act_layer,
|
|
|
|
|
)
|
|
|
|
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
else:
|
|
|
|
|
self.norm2 = None
|
|
|
|
|
self.mlp = None
|
|
|
|
|
self.drop_path2 = None
|
|
|
|
|
|
|
|
|
|
def forward(self, x: Tensor):
|
|
|
|
|
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
|
|
|
|
|
x = self.cpe1(x).flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
cur = self.norm1(x)
|
|
|
|
|
cur = self.attn(cur)
|
|
|
|
|
x = x + self.drop_path(cur)
|
|
|
|
|
x = x + self.drop_path1(cur)
|
|
|
|
|
|
|
|
|
|
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
|
|
|
|
|
if self.ffn:
|
|
|
|
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
|
|
|
|
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
|
|
|
|
|
|
|
|
|
|
if self.mlp is not None:
|
|
|
|
|
x = x.flatten(2).transpose(1, 2)
|
|
|
|
|
x = x + self.drop_path2(self.mlp(self.norm2(x)))
|
|
|
|
|
x = x.transpose(1, 2).view(B, C, H, W)
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def window_partition(x : Tensor, window_size: int):
|
|
|
|
|
|
|
|
|
|
def window_partition(x: Tensor, window_size: Tuple[int, int]):
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
x: (B, H, W, C)
|
|
|
|
@ -192,12 +200,13 @@ def window_partition(x : Tensor, window_size: int):
|
|
|
|
|
windows: (num_windows*B, window_size, window_size, C)
|
|
|
|
|
"""
|
|
|
|
|
B, H, W, C = x.shape
|
|
|
|
|
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
|
|
|
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
|
|
|
|
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
|
|
|
|
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
|
|
|
|
|
return windows
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_notrace_function # reason: int argument is a Proxy
|
|
|
|
|
def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
|
|
|
|
|
def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int):
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
windows: (num_windows*B, window_size, window_size, C)
|
|
|
|
@ -207,9 +216,8 @@ def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
|
|
|
|
|
Returns:
|
|
|
|
|
x: (B, H, W, C)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
|
|
|
|
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
|
|
|
|
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
|
|
|
|
|
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
|
|
|
|
|
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
@ -225,7 +233,6 @@ class WindowAttention(nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, window_size, num_heads, qkv_bias=True):
|
|
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.dim = dim
|
|
|
|
|
self.window_size = window_size
|
|
|
|
@ -242,7 +249,7 @@ class WindowAttention(nn.Module):
|
|
|
|
|
B_, N, C = x.shape
|
|
|
|
|
|
|
|
|
|
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
|
|
|
q, k, v = qkv[0], qkv[1], qkv[2]
|
|
|
|
|
q, k, v = qkv.unbind(0)
|
|
|
|
|
|
|
|
|
|
q = q * self.scale
|
|
|
|
|
attn = (q @ k.transpose(-2, -1))
|
|
|
|
@ -266,74 +273,86 @@ class SpatialBlock(nn.Module):
|
|
|
|
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, num_heads, window_size=7,
|
|
|
|
|
mlp_ratio=4., qkv_bias=True, drop_path=0.,
|
|
|
|
|
act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
|
|
|
|
ffn=True, cpe_act=False):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
dim,
|
|
|
|
|
num_heads,
|
|
|
|
|
window_size=7,
|
|
|
|
|
mlp_ratio=4.,
|
|
|
|
|
qkv_bias=True,
|
|
|
|
|
drop_path=0.,
|
|
|
|
|
act_layer=nn.GELU,
|
|
|
|
|
norm_layer=nn.LayerNorm,
|
|
|
|
|
ffn=True,
|
|
|
|
|
cpe_act=False,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.dim = dim
|
|
|
|
|
self.ffn = ffn
|
|
|
|
|
self.num_heads = num_heads
|
|
|
|
|
self.window_size = window_size
|
|
|
|
|
self.window_size = to_2tuple(window_size)
|
|
|
|
|
self.mlp_ratio = mlp_ratio
|
|
|
|
|
|
|
|
|
|
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
|
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
|
self.attn = WindowAttention(
|
|
|
|
|
dim,
|
|
|
|
|
window_size=to_2tuple(self.window_size),
|
|
|
|
|
self.window_size,
|
|
|
|
|
num_heads=num_heads,
|
|
|
|
|
qkv_bias=qkv_bias)
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
)
|
|
|
|
|
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
|
|
|
|
|
|
|
|
|
|
if self.ffn:
|
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
|
|
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
|
|
|
self.mlp = Mlp(
|
|
|
|
|
in_features=dim,
|
|
|
|
|
hidden_features=mlp_hidden_dim,
|
|
|
|
|
act_layer=act_layer)
|
|
|
|
|
|
|
|
|
|
act_layer=act_layer,
|
|
|
|
|
)
|
|
|
|
|
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
else:
|
|
|
|
|
self.norm2 = None
|
|
|
|
|
self.mlp = None
|
|
|
|
|
self.drop_path1 = None
|
|
|
|
|
|
|
|
|
|
def forward(self, x: Tensor):
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
x = self.norm1(shortcut)
|
|
|
|
|
x = x.view(B, H, W, C)
|
|
|
|
|
|
|
|
|
|
pad_l = pad_t = 0
|
|
|
|
|
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
|
|
|
|
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
|
|
|
|
pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
|
|
|
|
|
pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
|
|
|
|
|
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
|
|
|
|
_, Hp, Wp, _ = x.shape
|
|
|
|
|
|
|
|
|
|
x_windows = window_partition(x, self.window_size)
|
|
|
|
|
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
|
|
|
|
|
x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
|
|
|
|
|
|
|
|
|
|
# W-MSA/SW-MSA
|
|
|
|
|
attn_windows = self.attn(x_windows)
|
|
|
|
|
|
|
|
|
|
# merge windows
|
|
|
|
|
attn_windows = attn_windows.view(-1,
|
|
|
|
|
self.window_size,
|
|
|
|
|
self.window_size,
|
|
|
|
|
C)
|
|
|
|
|
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
|
|
|
|
|
x = window_reverse(attn_windows, self.window_size, Hp, Wp)
|
|
|
|
|
|
|
|
|
|
# if pad_r > 0 or pad_b > 0:
|
|
|
|
|
x = x[:, :H, :W, :].contiguous()
|
|
|
|
|
|
|
|
|
|
x = x.view(B, H * W, C)
|
|
|
|
|
x = shortcut + self.drop_path(x)
|
|
|
|
|
x = shortcut + self.drop_path1(x)
|
|
|
|
|
|
|
|
|
|
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
|
|
|
|
|
if self.ffn:
|
|
|
|
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
|
|
|
|
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
|
|
|
|
|
|
|
|
|
|
if self.mlp is not None:
|
|
|
|
|
x = x.flatten(2).transpose(1, 2)
|
|
|
|
|
x = x + self.drop_path2(self.mlp(self.norm2(x)))
|
|
|
|
|
x = x.transpose(1, 2).view(B, C, H, W)
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
@ -345,15 +364,15 @@ class DaViTStage(nn.Module):
|
|
|
|
|
in_chs,
|
|
|
|
|
out_chs,
|
|
|
|
|
depth=1,
|
|
|
|
|
patch_size = 4,
|
|
|
|
|
overlapped_patch = False,
|
|
|
|
|
attention_types = ('spatial', 'channel'),
|
|
|
|
|
downsample=True,
|
|
|
|
|
attn_types=('spatial', 'channel'),
|
|
|
|
|
num_heads=3,
|
|
|
|
|
window_size=7,
|
|
|
|
|
mlp_ratio=4,
|
|
|
|
|
qkv_bias=True,
|
|
|
|
|
drop_path_rates=(0, 0),
|
|
|
|
|
norm_layer = nn.LayerNorm,
|
|
|
|
|
norm_layer=LayerNorm2d,
|
|
|
|
|
norm_layer_cl=nn.LayerNorm,
|
|
|
|
|
ffn=True,
|
|
|
|
|
cpe_act=False
|
|
|
|
|
):
|
|
|
|
@ -361,13 +380,12 @@ class DaViTStage(nn.Module):
|
|
|
|
|
|
|
|
|
|
self.grad_checkpointing = False
|
|
|
|
|
|
|
|
|
|
# patch embedding layer at the beginning of each stage
|
|
|
|
|
self.patch_embed = PatchEmbed(
|
|
|
|
|
patch_size=patch_size,
|
|
|
|
|
in_chans=in_chs,
|
|
|
|
|
embed_dim=out_chs,
|
|
|
|
|
overlapped=overlapped_patch
|
|
|
|
|
)
|
|
|
|
|
# downsample embedding layer at the beginning of each stage
|
|
|
|
|
if downsample:
|
|
|
|
|
self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
|
|
|
|
|
else:
|
|
|
|
|
self.downsample = nn.Identity()
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
repeating alternating attention blocks in each stage
|
|
|
|
|
default: (spatial -> channel) x depth
|
|
|
|
@ -377,36 +395,32 @@ class DaViTStage(nn.Module):
|
|
|
|
|
'''
|
|
|
|
|
stage_blocks = []
|
|
|
|
|
for block_idx in range(depth):
|
|
|
|
|
|
|
|
|
|
dual_attention_block = []
|
|
|
|
|
|
|
|
|
|
for attention_id, attention_type in enumerate(attention_types):
|
|
|
|
|
if attention_type == 'spatial':
|
|
|
|
|
for attn_idx, attn_type in enumerate(attn_types):
|
|
|
|
|
if attn_type == 'spatial':
|
|
|
|
|
dual_attention_block.append(SpatialBlock(
|
|
|
|
|
dim=out_chs,
|
|
|
|
|
num_heads=num_heads,
|
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
|
|
|
|
|
norm_layer=norm_layer,
|
|
|
|
|
drop_path=drop_path_rates[block_idx],
|
|
|
|
|
norm_layer=norm_layer_cl,
|
|
|
|
|
ffn=ffn,
|
|
|
|
|
cpe_act=cpe_act,
|
|
|
|
|
window_size=window_size,
|
|
|
|
|
))
|
|
|
|
|
elif attention_type == 'channel':
|
|
|
|
|
elif attn_type == 'channel':
|
|
|
|
|
dual_attention_block.append(ChannelBlock(
|
|
|
|
|
dim=out_chs,
|
|
|
|
|
num_heads=num_heads,
|
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
|
|
|
|
|
norm_layer=norm_layer,
|
|
|
|
|
drop_path=drop_path_rates[block_idx],
|
|
|
|
|
norm_layer=norm_layer_cl,
|
|
|
|
|
ffn=ffn,
|
|
|
|
|
cpe_act=cpe_act
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
stage_blocks.append(nn.Sequential(*dual_attention_block))
|
|
|
|
|
|
|
|
|
|
self.blocks = nn.Sequential(*stage_blocks)
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
@ -414,7 +428,7 @@ class DaViTStage(nn.Module):
|
|
|
|
|
self.grad_checkpointing = enable
|
|
|
|
|
|
|
|
|
|
def forward(self, x: Tensor):
|
|
|
|
|
x = self.patch_embed(x)
|
|
|
|
|
x = self.downsample(x)
|
|
|
|
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
|
|
|
x = checkpoint_seq(self.blocks, x)
|
|
|
|
|
else:
|
|
|
|
@ -431,7 +445,6 @@ class DaViT(nn.Module):
|
|
|
|
|
in_chans (int): Number of input image channels. Default: 3
|
|
|
|
|
num_classes (int): Number of classes for classification head. Default: 1000
|
|
|
|
|
depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
|
|
|
|
|
patch_size (int | tuple(int)): Patch size. Default: 4
|
|
|
|
|
embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
|
|
|
|
|
num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
|
|
|
|
|
window_size (int): Window size. Default: 7
|
|
|
|
@ -445,69 +458,61 @@ class DaViT(nn.Module):
|
|
|
|
|
self,
|
|
|
|
|
in_chans=3,
|
|
|
|
|
depths=(1, 1, 3, 1),
|
|
|
|
|
patch_size=4,
|
|
|
|
|
embed_dims=(96, 192, 384, 768),
|
|
|
|
|
num_heads=(3, 6, 12, 24),
|
|
|
|
|
window_size=7,
|
|
|
|
|
mlp_ratio=4.,
|
|
|
|
|
mlp_ratio=4,
|
|
|
|
|
qkv_bias=True,
|
|
|
|
|
drop_path_rate=0.1,
|
|
|
|
|
norm_layer=nn.LayerNorm,
|
|
|
|
|
attention_types=('spatial', 'channel'),
|
|
|
|
|
norm_layer='layernorm2d',
|
|
|
|
|
norm_layer_cl='layernorm',
|
|
|
|
|
norm_eps=1e-5,
|
|
|
|
|
attn_types=('spatial', 'channel'),
|
|
|
|
|
ffn=True,
|
|
|
|
|
overlapped_patch=False,
|
|
|
|
|
cpe_act=False,
|
|
|
|
|
drop_rate=0.,
|
|
|
|
|
attn_drop_rate=0.,
|
|
|
|
|
drop_path_rate=0.,
|
|
|
|
|
num_classes=1000,
|
|
|
|
|
global_pool='avg',
|
|
|
|
|
head_norm_first=False,
|
|
|
|
|
):
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
architecture = [[index] * item for index, item in enumerate(depths)]
|
|
|
|
|
self.architecture = architecture
|
|
|
|
|
self.embed_dims = embed_dims
|
|
|
|
|
self.num_heads = num_heads
|
|
|
|
|
self.num_stages = len(self.embed_dims)
|
|
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, len(attention_types) * len(list(itertools.chain(*self.architecture))))]
|
|
|
|
|
assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1)
|
|
|
|
|
|
|
|
|
|
num_stages = len(embed_dims)
|
|
|
|
|
assert num_stages == len(num_heads) == len(depths)
|
|
|
|
|
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
|
|
|
|
|
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.num_features = embed_dims[-1]
|
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
|
self.grad_checkpointing = False
|
|
|
|
|
self.feature_info = []
|
|
|
|
|
|
|
|
|
|
self.patch_embed = None
|
|
|
|
|
stages = []
|
|
|
|
|
|
|
|
|
|
for stage_id in range(self.num_stages):
|
|
|
|
|
stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
|
|
|
|
|
self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer)
|
|
|
|
|
in_chs = embed_dims[0]
|
|
|
|
|
|
|
|
|
|
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
|
|
|
|
stages = []
|
|
|
|
|
for stage_idx in range(num_stages):
|
|
|
|
|
out_chs = embed_dims[stage_idx]
|
|
|
|
|
stage = DaViTStage(
|
|
|
|
|
in_chans if stage_id == 0 else embed_dims[stage_id - 1],
|
|
|
|
|
embed_dims[stage_id],
|
|
|
|
|
depth = depths[stage_id],
|
|
|
|
|
patch_size = patch_size if stage_id == 0 else 2,
|
|
|
|
|
overlapped_patch = overlapped_patch,
|
|
|
|
|
attention_types = attention_types,
|
|
|
|
|
num_heads = num_heads[stage_id],
|
|
|
|
|
in_chs,
|
|
|
|
|
out_chs,
|
|
|
|
|
depth=depths[stage_idx],
|
|
|
|
|
downsample=stage_idx > 0,
|
|
|
|
|
attn_types=attn_types,
|
|
|
|
|
num_heads=num_heads[stage_idx],
|
|
|
|
|
window_size=window_size,
|
|
|
|
|
mlp_ratio=mlp_ratio,
|
|
|
|
|
qkv_bias=qkv_bias,
|
|
|
|
|
drop_path_rates = stage_drop_rates,
|
|
|
|
|
norm_layer = nn.LayerNorm,
|
|
|
|
|
drop_path_rates=dpr[stage_idx],
|
|
|
|
|
norm_layer=norm_layer,
|
|
|
|
|
norm_layer_cl=norm_layer_cl,
|
|
|
|
|
ffn=ffn,
|
|
|
|
|
cpe_act = cpe_act
|
|
|
|
|
cpe_act=cpe_act,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if stage_id == 0:
|
|
|
|
|
self.patch_embed = stage.patch_embed
|
|
|
|
|
stage.patch_embed = nn.Identity()
|
|
|
|
|
|
|
|
|
|
in_chs = out_chs
|
|
|
|
|
stages.append(stage)
|
|
|
|
|
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
|
|
|
|
|
self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]
|
|
|
|
|
|
|
|
|
|
self.stages = nn.Sequential(*stages)
|
|
|
|
|
|
|
|
|
@ -529,9 +534,6 @@ class DaViT(nn.Module):
|
|
|
|
|
trunc_normal_(m.weight, std=.02)
|
|
|
|
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
elif isinstance(m, nn.LayerNorm):
|
|
|
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
|
nn.init.constant_(m.weight, 1.0)
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
|
def set_grad_checkpointing(self, enable=True):
|
|
|
|
@ -550,17 +552,17 @@ class DaViT(nn.Module):
|
|
|
|
|
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
x = self.patch_embed(x)
|
|
|
|
|
x = self.stem(x)
|
|
|
|
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
|
|
|
x = checkpoint_seq(self.stages, x)
|
|
|
|
|
else:
|
|
|
|
|
x = self.stages(x)
|
|
|
|
|
x = self.norm_pre(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
|
|
x = self.norm_pre(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
|
|
|
|
|
x = self.head.global_pool(x)
|
|
|
|
|
x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
|
|
|
|
x = self.head.norm(x)
|
|
|
|
|
x = self.head.flatten(x)
|
|
|
|
|
x = self.head.drop(x)
|
|
|
|
|
return x if pre_logits else self.head.fc(x)
|
|
|
|
@ -573,7 +575,7 @@ class DaViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
""" Remap MSFT checkpoints -> timm """
|
|
|
|
|
if 'head' in state_dict:
|
|
|
|
|
if 'head.fc.weight' in state_dict:
|
|
|
|
|
return state_dict # non-MSFT checkpoint
|
|
|
|
|
|
|
|
|
|
if 'state_dict' in state_dict:
|
|
|
|
@ -582,10 +584,10 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
import re
|
|
|
|
|
out_dict = {}
|
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
|
|
|
|
|
|
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
|
|
|
|
|
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k)
|
|
|
|
|
k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
|
|
|
|
|
k = k.replace('stages.0.patch_embed', 'patch_embed')
|
|
|
|
|
k = k.replace('downsample.proj', 'downsample.conv')
|
|
|
|
|
k = k.replace('stages.0.downsample', 'stem')
|
|
|
|
|
k = k.replace('head.', 'head.fc.')
|
|
|
|
|
k = k.replace('norms.', 'head.norm.')
|
|
|
|
|
k = k.replace('cpe.0', 'cpe1')
|
|
|
|
@ -595,7 +597,6 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
|
|
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
|
|
|
|
out_indices = kwargs.pop('out_indices', default_out_indices)
|
|
|
|
|
|
|
|
|
@ -610,67 +611,69 @@ def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
|
|
|
return {
|
|
|
|
|
'url': url,
|
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
|
|
|
|
'crop_pct': 0.850, 'interpolation': 'bicubic',
|
|
|
|
|
'crop_pct': 0.95, 'interpolation': 'bicubic',
|
|
|
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
|
|
|
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
|
|
|
|
|
'first_conv': 'stem.conv', 'classifier': 'head.fc',
|
|
|
|
|
**kwargs
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO contact authors to get larger pretrained models
|
|
|
|
|
default_cfgs = generate_default_cfgs({
|
|
|
|
|
# official microsoft weights from https://github.com/dingmyu/davit
|
|
|
|
|
'davit_tiny.msft_in1k': _cfg(
|
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
|
|
|
|
|
hf_hub_id='timm/'),
|
|
|
|
|
'davit_small.msft_in1k': _cfg(
|
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
|
|
|
|
|
hf_hub_id='timm/'),
|
|
|
|
|
'davit_base.msft_in1k': _cfg(
|
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
|
|
|
|
|
hf_hub_id='timm/'),
|
|
|
|
|
'davit_large': _cfg(),
|
|
|
|
|
'davit_huge': _cfg(),
|
|
|
|
|
'davit_giant': _cfg(),
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def davit_tiny(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768),
|
|
|
|
|
num_heads=(3, 6, 12, 24), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
|
|
|
|
|
return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def davit_small(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768),
|
|
|
|
|
num_heads=(3, 6, 12, 24), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
|
|
|
|
|
return _create_davit('davit_small', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def davit_base(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024),
|
|
|
|
|
num_heads=(4, 8, 16, 32), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32), **kwargs)
|
|
|
|
|
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def davit_large(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),
|
|
|
|
|
num_heads=(6, 12, 24, 48), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48), **kwargs)
|
|
|
|
|
return _create_davit('davit_large', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def davit_huge(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048),
|
|
|
|
|
num_heads=(8, 16, 32, 64), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64), **kwargs)
|
|
|
|
|
return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def davit_giant(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
|
|
|
|
|
num_heads=(12, 24, 48, 96), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96), **kwargs)
|
|
|
|
|
return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
|
|
|
|
|