|
|
@ -45,7 +45,7 @@ class SequentialWithSize(nn.Sequential):
|
|
|
|
return x, size
|
|
|
|
return x, size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvPosEnc(nn.Module):
|
|
|
|
class ConvPosEncOld(nn.Module):
|
|
|
|
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
|
|
|
|
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
|
|
|
|
|
|
|
|
|
|
|
|
super(ConvPosEnc, self).__init__()
|
|
|
|
super(ConvPosEnc, self).__init__()
|
|
|
@ -79,10 +79,44 @@ class ConvPosEnc(nn.Module):
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvPosEnc(nn.Module):
|
|
|
|
|
|
|
|
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
super(ConvPosEnc, self).__init__()
|
|
|
|
|
|
|
|
self.proj = nn.Conv2d(dim,
|
|
|
|
|
|
|
|
dim,
|
|
|
|
|
|
|
|
to_2tuple(k),
|
|
|
|
|
|
|
|
to_2tuple(1),
|
|
|
|
|
|
|
|
to_2tuple(k // 2),
|
|
|
|
|
|
|
|
groups=dim)
|
|
|
|
|
|
|
|
self.normtype = normtype
|
|
|
|
|
|
|
|
self.norm = nn.Identity()
|
|
|
|
|
|
|
|
if self.normtype == 'batch':
|
|
|
|
|
|
|
|
self.norm = nn.BatchNorm2d(dim)
|
|
|
|
|
|
|
|
elif self.normtype == 'layer':
|
|
|
|
|
|
|
|
self.norm = nn.LayerNorm(dim)
|
|
|
|
|
|
|
|
self.activation = nn.GELU() if act else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor):
|
|
|
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#feat = x.transpose(1, 2).view(B, C, H, W)
|
|
|
|
|
|
|
|
feat = self.proj(feat)
|
|
|
|
|
|
|
|
if self.normtype == 'batch':
|
|
|
|
|
|
|
|
feat = self.norm(feat).flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
elif self.normtype == 'layer':
|
|
|
|
|
|
|
|
feat = self.norm(feat.flatten(2).transpose(1, 2))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
feat = feat.flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
x = x + self.activation(feat).transpose(1, 2).view(B, C, H, W)
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# reason: dim in control sequence
|
|
|
|
# reason: dim in control sequence
|
|
|
|
# FIXME reimplement to allow tracing
|
|
|
|
# FIXME reimplement to allow tracing
|
|
|
|
@register_notrace_module
|
|
|
|
@register_notrace_module
|
|
|
|
class PatchEmbed(nn.Module):
|
|
|
|
class PatchEmbedOld(nn.Module):
|
|
|
|
""" Size-agnostic implementation of 2D image to patch embedding,
|
|
|
|
""" Size-agnostic implementation of 2D image to patch embedding,
|
|
|
|
allowing input size to be adjusted during model forward operation
|
|
|
|
allowing input size to be adjusted during model forward operation
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -141,6 +175,60 @@ class PatchEmbed(nn.Module):
|
|
|
|
x = self.norm(x)
|
|
|
|
x = self.norm(x)
|
|
|
|
return x, newsize
|
|
|
|
return x, newsize
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_notrace_module
|
|
|
|
|
|
|
|
class PatchEmbed(nn.Module):
|
|
|
|
|
|
|
|
""" Size-agnostic implementation of 2D image to patch embedding,
|
|
|
|
|
|
|
|
allowing input size to be adjusted during model forward operation
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self,
|
|
|
|
|
|
|
|
patch_size=4,
|
|
|
|
|
|
|
|
in_chans=3,
|
|
|
|
|
|
|
|
embed_dim=96,
|
|
|
|
|
|
|
|
overlapped=False):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
patch_size = to_2tuple(patch_size)
|
|
|
|
|
|
|
|
self.patch_size = patch_size
|
|
|
|
|
|
|
|
self.in_chans = in_chans
|
|
|
|
|
|
|
|
self.embed_dim = embed_dim
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if patch_size[0] == 4:
|
|
|
|
|
|
|
|
self.proj = nn.Conv2d(
|
|
|
|
|
|
|
|
in_chans,
|
|
|
|
|
|
|
|
embed_dim,
|
|
|
|
|
|
|
|
kernel_size=(7, 7),
|
|
|
|
|
|
|
|
stride=patch_size,
|
|
|
|
|
|
|
|
padding=(3, 3))
|
|
|
|
|
|
|
|
self.norm = nn.LayerNorm(embed_dim)
|
|
|
|
|
|
|
|
if patch_size[0] == 2:
|
|
|
|
|
|
|
|
kernel = 3 if overlapped else 2
|
|
|
|
|
|
|
|
pad = 1 if overlapped else 0
|
|
|
|
|
|
|
|
self.proj = nn.Conv2d(
|
|
|
|
|
|
|
|
in_chans,
|
|
|
|
|
|
|
|
embed_dim,
|
|
|
|
|
|
|
|
kernel_size=to_2tuple(kernel),
|
|
|
|
|
|
|
|
stride=patch_size,
|
|
|
|
|
|
|
|
padding=to_2tuple(pad))
|
|
|
|
|
|
|
|
self.norm = nn.LayerNorm(in_chans)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor):
|
|
|
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
|
|
|
if self.norm.normalized_shape[0] == self.in_chans:
|
|
|
|
|
|
|
|
x = self.norm(x.flatten(2).transpose(1, 2)).transpose(1, 2).view(B, C, H, W)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if W % self.patch_size[1] != 0:
|
|
|
|
|
|
|
|
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
|
|
|
|
|
|
|
if H % self.patch_size[0] != 0:
|
|
|
|
|
|
|
|
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = self.proj(x)
|
|
|
|
|
|
|
|
#x = x.flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
if self.norm.normalized_shape[0] == self.embed_dim:
|
|
|
|
|
|
|
|
x = self.norm(x.flatten(2).transpose(1, 2)).transpose(1, 2).view(B, C, H, W)
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChannelAttention(nn.Module):
|
|
|
|
class ChannelAttention(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
@ -167,8 +255,36 @@ class ChannelAttention(nn.Module):
|
|
|
|
x = self.proj(x)
|
|
|
|
x = self.proj(x)
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChannelAttentionNew(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, num_heads=8, qkv_bias=False):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
self.num_heads = num_heads
|
|
|
|
|
|
|
|
head_dim = dim // num_heads
|
|
|
|
|
|
|
|
self.scale = head_dim ** -0.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
|
|
|
|
|
|
|
self.proj = nn.Linear(dim, dim)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor):
|
|
|
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
|
|
|
x = x.flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
B, N, C = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
|
|
|
|
|
|
q, k, v = qkv[0], qkv[1], qkv[2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
k = k * self.scale
|
|
|
|
|
|
|
|
attention = k.transpose(-1, -2) @ v
|
|
|
|
|
|
|
|
attention = attention.softmax(dim=-1)
|
|
|
|
|
|
|
|
x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
|
|
|
|
|
|
|
|
x = x.transpose(1, 2).reshape(B, N, C)
|
|
|
|
|
|
|
|
x = self.proj(x)
|
|
|
|
|
|
|
|
x = x.transpose(1, 2).view(B, C, H, W)
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
class ChannelBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
class ChannelBlockOld(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
|
|
|
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
|
|
|
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
|
|
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
|
|
@ -204,6 +320,46 @@ class ChannelBlock(nn.Module):
|
|
|
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
|
|
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
|
|
|
return x, size
|
|
|
|
return x, size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChannelBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
|
|
|
|
|
|
|
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
|
|
|
|
|
|
|
ffn=True, cpe_act=False):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
|
|
|
|
|
|
|
|
self.ffn = ffn
|
|
|
|
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
|
|
|
|
self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
|
|
|
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
|
|
|
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.ffn:
|
|
|
|
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
|
|
|
|
|
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
|
|
|
|
|
|
self.mlp = Mlp(
|
|
|
|
|
|
|
|
in_features=dim,
|
|
|
|
|
|
|
|
hidden_features=mlp_hidden_dim,
|
|
|
|
|
|
|
|
act_layer=act_layer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = self.cpe1(x).flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cur = self.norm1(x)
|
|
|
|
|
|
|
|
cur = self.attn(cur)
|
|
|
|
|
|
|
|
x = x + self.drop_path(cur)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
if self.ffn:
|
|
|
|
|
|
|
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = x.transpose(1, 2).view(B, C, H, W)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
def window_partition(x : Tensor, window_size: int):
|
|
|
|
def window_partition(x : Tensor, window_size: int):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -319,6 +475,91 @@ class SpatialBlock(nn.Module):
|
|
|
|
act_layer=act_layer)
|
|
|
|
act_layer=act_layer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
shortcut = self.cpe1(x, size).flatten(2).transpose(1, 2)
|
|
|
|
|
|
|
|
x = self.norm1(shortcut)
|
|
|
|
|
|
|
|
x = x.view(B, H, W, C)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pad_l = pad_t = 0
|
|
|
|
|
|
|
|
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
|
|
|
|
|
|
|
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
|
|
|
|
|
|
|
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
|
|
|
|
|
|
|
_, Hp, Wp, _ = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x_windows = window_partition(x, self.window_size)
|
|
|
|
|
|
|
|
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# W-MSA/SW-MSA
|
|
|
|
|
|
|
|
attn_windows = self.attn(x_windows)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# merge windows
|
|
|
|
|
|
|
|
attn_windows = attn_windows.view(-1,
|
|
|
|
|
|
|
|
self.window_size,
|
|
|
|
|
|
|
|
self.window_size,
|
|
|
|
|
|
|
|
C)
|
|
|
|
|
|
|
|
x = window_reverse(attn_windows, self.window_size, Hp, Wp)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#if pad_r > 0 or pad_b > 0:
|
|
|
|
|
|
|
|
x = x[:, :H, :W, :].contiguous()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = x.view(B, H * W, C)
|
|
|
|
|
|
|
|
x = shortcut + self.drop_path(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = self.cpe2(x, size)
|
|
|
|
|
|
|
|
if self.ffn:
|
|
|
|
|
|
|
|
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = x.transpose(1, 2).view(B, C, H, W)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return x, size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpatialBlockOld(nn.Module):
|
|
|
|
|
|
|
|
r""" Windows Block.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
dim (int): Number of input channels.
|
|
|
|
|
|
|
|
num_heads (int): Number of attention heads.
|
|
|
|
|
|
|
|
window_size (int): Window size.
|
|
|
|
|
|
|
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
|
|
|
|
|
|
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
|
|
|
|
|
|
|
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
|
|
|
|
|
|
|
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
|
|
|
|
|
|
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, num_heads, window_size=7,
|
|
|
|
|
|
|
|
mlp_ratio=4., qkv_bias=True, drop_path=0.,
|
|
|
|
|
|
|
|
act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
|
|
|
|
|
|
|
ffn=True, cpe_act=False):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
self.dim = dim
|
|
|
|
|
|
|
|
self.ffn = ffn
|
|
|
|
|
|
|
|
self.num_heads = num_heads
|
|
|
|
|
|
|
|
self.window_size = window_size
|
|
|
|
|
|
|
|
self.mlp_ratio = mlp_ratio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
|
|
|
|
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
|
|
|
|
self.attn = WindowAttention(
|
|
|
|
|
|
|
|
dim,
|
|
|
|
|
|
|
|
window_size=to_2tuple(self.window_size),
|
|
|
|
|
|
|
|
num_heads=num_heads,
|
|
|
|
|
|
|
|
qkv_bias=qkv_bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
|
|
|
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.ffn:
|
|
|
|
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
|
|
|
|
|
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
|
|
|
|
|
|
|
self.mlp = Mlp(
|
|
|
|
|
|
|
|
in_features=dim,
|
|
|
|
|
|
|
|
hidden_features=mlp_hidden_dim,
|
|
|
|
|
|
|
|
act_layer=act_layer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
|
|
|
|
|
|
|
|
H, W = size
|
|
|
|
H, W = size
|
|
|
@ -424,9 +665,9 @@ class DaViTStage(nn.Module):
|
|
|
|
cpe_act=cpe_act
|
|
|
|
cpe_act=cpe_act
|
|
|
|
))
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
|
|
stage_blocks.append(SequentialWithSize(*dual_attention_block))
|
|
|
|
stage_blocks.append(nn.Sequential(*dual_attention_block))
|
|
|
|
|
|
|
|
|
|
|
|
self.blocks = SequentialWithSize(*stage_blocks)
|
|
|
|
self.blocks = nn.Sequential(*stage_blocks)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
x, size = self.patch_embed(x, size)
|
|
|
|
x, size = self.patch_embed(x, size)
|
|
|
@ -519,7 +760,7 @@ class DaViT(nn.Module):
|
|
|
|
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
|
|
|
|
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.stages = SequentialWithSize(*stages)
|
|
|
|
self.stages = nn.Sequential(*stages)
|
|
|
|
|
|
|
|
|
|
|
|
self.norms = norm_layer(self.num_features)
|
|
|
|
self.norms = norm_layer(self.num_features)
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
|
|
|