@@ -52,7 +52,7 @@ class ConvPosEnc(nn.Module):
         self.norm = nn.LayerNorm(dim)
         self.activation = nn.GELU() if act else nn.Identity()
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x: Tensor, size: Tuple[int, int]):
        B, N, C = x.shape
        H, W = size
        assert N == H * W
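
The annotation on size is what this hunk is really for: under torch.jit.script, untyped parameters are assumed to be Tensor, so a plain-tuple argument only compiles once it is spelled Tuple[int, int]. A minimal sketch of the pattern (the module and names below are illustrative, not taken from this diff):

    from typing import Tuple

    import torch
    from torch import Tensor, nn

    class TinyCpe(nn.Module):
        """Depthwise-conv positional encoding, shaped like ConvPosEnc."""
        def __init__(self, dim: int = 8):
            super().__init__()
            self.proj = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)

        def forward(self, x: Tensor, size: Tuple[int, int]):
            B, N, C = x.shape
            H, W = size
            assert N == H * W
            # tokens -> 2D grid -> depthwise conv -> tokens, added residually
            feat = x.transpose(1, 2).view(B, C, H, W)
            return x + self.proj(feat).flatten(2).transpose(1, 2)

    scripted = torch.jit.script(TinyCpe())  # compiles because size is typed
    out = scripted(torch.randn(2, 16, 8), (4, 4))
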
@@ -104,9 +104,9 @@ class PatchEmbed(nn.Module):
         self.norm = nn.LayerNorm(in_chans)
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x: Tensor, size: Tuple[int, int]):
         H, W = size
-        dim = len(x.shape)
+        dim = x.dim()
         if dim == 3:
             B, HW, C = x.shape
             x = self.norm(x)
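
Besides the signature change, this hunk swaps len(x.shape) for x.dim(). Both return the tensor's rank, but x.dim() is the conventional spelling, avoids building an intermediate size list, and is the safer choice under scripting. A quick equivalence check:

    import torch

    x = torch.randn(2, 16, 8)
    assert x.dim() == len(x.shape) == 3  # same rank either way
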
@@ -140,7 +140,7 @@ class ChannelAttention(nn.Module):
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim)
 
-    def forward(self, x):
+    def forward(self, x: Tensor):
         B, N, C = x.shape
 
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
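
The qkv line packs all three projections into one matmul and then splits heads. A shape walk-through with small illustrative values (not from the diff):

    import torch

    B, N, C, num_heads = 2, 16, 8, 2
    qkv = torch.randn(B, N, 3 * C)                    # output of nn.Linear(dim, dim * 3)
    qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)                           # each: (B, num_heads, N, C // num_heads)
    assert q.shape == (B, num_heads, N, C // num_heads)
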
@@ -178,7 +178,7 @@ class ChannelBlock(nn.Module):
                 act_layer=act_layer)
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x: Tensor, size: Tuple[int, int]):
         x = self.cpe[0](x, size)
         cur = self.norm1(x)
         cur = self.attn(cur)
@@ -190,7 +190,7 @@ class ChannelBlock(nn.Module):
         return x, size
 
 
-def window_partition(x, window_size: int):
+def window_partition(x: Tensor, window_size: int):
     """
     Args:
         x: (B, H, W, C)
@@ -204,7 +204,7 @@ def window_partition(x, window_size: int):
     return windows
 
 
-def window_reverse(windows, window_size: int, H: int, W: int):
+def window_reverse(windows: Tensor, window_size: int, H: int, W: int):
     """
     Args:
         windows: (num_windows*B, window_size, window_size, C)
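
window_partition and window_reverse are exact inverses, which is worth checking since the diff only shows their signatures. The bodies below are the standard Swin-style implementations, paraphrased here as an assumption about the surrounding code:

    import torch
    from torch import Tensor

    def window_partition(x: Tensor, window_size: int):
        B, H, W, C = x.shape
        x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

    def window_reverse(windows: Tensor, window_size: int, H: int, W: int):
        B = int(windows.shape[0] / (H * W / window_size / window_size))
        x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
        return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

    x = torch.randn(2, 8, 8, 4)
    assert torch.equal(window_reverse(window_partition(x, 4), 4, 8, 8), x)
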
@@ -244,7 +244,7 @@ class WindowAttention(nn.Module):
 
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x):
+    def forward(self, x: Tensor):
         B_, N, C = x.shape
 
         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
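
WindowAttention reuses the same qkv packing as ChannelAttention; what differs is which axis the softmax runs over. A hedged contrast of the two attention flavors (illustrative shapes, not the PR's exact code):

    import torch

    B, heads, N, c = 2, 2, 16, 4                     # c = C // num_heads
    q, k, v = (torch.randn(B, heads, N, c) for _ in range(3))

    # spatial/window attention: N x N map, softmax over tokens
    spatial = (q @ k.transpose(-2, -1)).softmax(dim=-1) @ v

    # channel attention: c x c map, softmax over channels
    channel = ((q.transpose(-2, -1) @ k) * N ** -0.5).softmax(dim=-1)
    channel_out = (channel @ v.transpose(-2, -1)).transpose(-2, -1)

    assert spatial.shape == channel_out.shape == (B, heads, N, c)
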
@@ -303,7 +303,7 @@ class SpatialBlock(nn.Module):
                 act_layer=act_layer)
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x: Tensor, size: Tuple[int, int]):
         H, W = size
         B, L, C = x.shape
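
The (B, L, C) sequence with L == H * W is about to be re-windowed, and Swin-style blocks pad H and W up to window_size multiples before partitioning (the diff truncates that part of the body). A self-contained sketch of the usual padding step, under that assumption:

    import torch
    import torch.nn.functional as F
    from torch import Tensor

    def pad_to_window(x: Tensor, window_size: int) -> Tensor:
        """Pad a (B, H, W, C) map so H and W are multiples of window_size."""
        _, H, W, _ = x.shape
        pad_r = (window_size - W % window_size) % window_size
        pad_b = (window_size - H % window_size) % window_size
        return F.pad(x, (0, 0, 0, pad_r, 0, pad_b))  # pad order: C, then W, then H

    padded = pad_to_window(torch.randn(2, 7, 9, 4), window_size=4)
    assert padded.shape[1] == 8 and padded.shape[2] == 12
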
@@ -341,13 +341,14 @@ class SpatialBlock(nn.Module):
         x = self.cpe[1](x, size)
         if self.ffn:
             x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x, size
 
 
 class DaViT(nn.Module):
     r""" DaViT
         A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645
+        Supports arbitrary input sizes and pyramid feature extraction
 
     Args:
         in_chans (int): Number of input image channels. Default: 3
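
The residual branch x + self.drop_path(...) in the hunk above is stochastic depth. timm supplies a DropPath module for this; a minimal functional sketch of its semantics (an assumption about the unshown code, kept to the common formulation):

    import torch
    from torch import Tensor

    def drop_path(x: Tensor, p: float = 0.1, training: bool = True) -> Tensor:
        """Zero the whole residual branch per sample with prob p, rescale the rest."""
        if p == 0.0 or not training:
            return x
        keep = 1.0 - p
        mask = x.new_empty(x.shape[0], *([1] * (x.dim() - 1))).bernoulli_(keep)
        return x * mask / keep
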
@@ -526,7 +527,7 @@ class DaViTFeatures(DaViT):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (0, 1, 2, 3)))
+        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
 
     def forward(self, x) -> List[Tensor]:
         return self.forward_pyramid_features(x)
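
With the 'out_inices' typo fixed, the keyword actually reaches FeatureInfo and stage selection works. A hedged usage sketch, assuming the usual timm entry points apply to this model family:

    import timm
    import torch

    model = timm.create_model('davit_tiny', features_only=True, out_indices=(1, 2, 3))
    feats = model(torch.randn(1, 3, 224, 224))
    for f in feats:
        print(f.shape)  # one (B, C_i, H_i, W_i) map per selected stage
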