Update davit.py

pull/1630/head
Fredo Guan 2 years ago
parent 205b24c69a
commit 2f6f103c5b

@ -341,88 +341,6 @@ class SpatialBlock(nn.Module):
return x
class SpatialBlockOld(nn.Module):
r""" Windows Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, num_heads, window_size=7,
mlp_ratio=4., qkv_bias=True, drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm,
ffn=True, cpe_act=False):
super().__init__()
self.dim = dim
self.ffn = ffn
self.num_heads = num_heads
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
if self.ffn:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer)
def forward(self, x : Tensor, size: Tuple[int, int]):
H, W = size
B, L, C = x.shape
shortcut = self.cpe1(x, size)
x = self.norm1(shortcut)
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
x_windows = window_partition(x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows)
# merge windows
attn_windows = attn_windows.view(-1,
self.window_size,
self.window_size,
C)
x = window_reverse(attn_windows, self.window_size, Hp, Wp)
#if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
x = self.cpe2(x, size)
if self.ffn:
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x, size
class DaViTStage(nn.Module):
def __init__(

Loading…
Cancel
Save