@@ -341,88 +341,6 @@ class SpatialBlock(nn.Module):
        return x


class SpatialBlockOld(nn.Module):
    r""" Window attention block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        ffn (bool, optional): If True, append an MLP sub-block after attention. Default: True
        cpe_act (bool, optional): If True, apply an activation inside the convolutional position encodings. Default: False
    """

    def __init__(self, dim, num_heads, window_size=7,
                 mlp_ratio=4., qkv_bias=True, drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 ffn=True, cpe_act=False):
        super().__init__()
        self.dim = dim
        self.ffn = ffn
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        # convolutional position encodings, applied before attention (cpe1)
        # and before the FFN (cpe2)
        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias)

        # stochastic depth: randomly skip the residual branch during training
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)

        if self.ffn:
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer)

    def forward(self, x: Tensor, size: Tuple[int, int]):
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = self.cpe1(x, size)
        x = self.norm1(shortcut)
        x = x.view(B, H, W, C)

        # pad the feature map so H and W become multiples of the window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape
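
        # Worked example of the padding arithmetic (illustrative values, not
        # from the original source): with H = W = 50 and window_size = 7,
        # pad_r = pad_b = (7 - 50 % 7) % 7 = 6, so Hp = Wp = 56 = 7 * 8.
        # The outer "% window_size" keeps the padding at 0 when H or W is
        # already divisible, instead of adding a full extra window.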

        x_windows = window_partition(x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
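
        # x_windows: (num_windows * B, window_size * window_size, C); each row
        # block is one window, so attention cost is quadratic only in the fixed
        # window size and linear in the number of windows.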

        # window multi-head self-attention (W-MSA)
        attn_windows = self.attn(x_windows)

        # merge windows back into a (B, Hp, Wp, C) feature map
        attn_windows = attn_windows.view(-1,
                                         self.window_size,
                                         self.window_size,
                                         C)
        x = window_reverse(attn_windows, self.window_size, Hp, Wp)

        # drop any padding added for the window partition
        x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)

        x = self.cpe2(x, size)
        if self.ffn:
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x, size
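

# A minimal usage sketch, added for illustration and not part of the original
# source. It assumes the helpers referenced above (WindowAttention, ConvPosEnc,
# Mlp, window_partition, window_reverse) are defined earlier in this file, and
# the sizes below are arbitrary example values.
def _spatial_block_old_demo():
    import torch

    block = SpatialBlockOld(dim=96, num_heads=3, window_size=7)
    H, W = 56, 56                        # divisible by window_size, no padding
    tokens = torch.randn(2, H * W, 96)   # (B, H*W, C) flattened feature map
    out, out_size = block(tokens, (H, W))
    assert out.shape == (2, H * W, 96) and out_size == (H, W)

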
class DaViTStage(nn.Module):
    def __init__(