Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 9d59ae5b63
commit 5e44a8109c

@@ -52,7 +52,7 @@ class ConvPosEnc(nn.Module):
         self.norm = nn.LayerNorm(dim)
         self.activation = nn.GELU() if act else nn.Identity()
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x : Tensor, size: Tuple[int, int]):
         B, N, C = x.shape
         H, W = size
         assert N == H * W
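The only change in this hunk is the explicit `x : Tensor` annotation. TorchScript already assumes unannotated arguments are Tensors, but spelling the type out keeps every `forward` signature uniform and scriptable alongside the `Tuple[int, int]` annotation. A minimal standalone sketch of the pattern (not timm code):

from typing import Tuple

import torch
from torch import Tensor, nn

class Annotated(nn.Module):
    def forward(self, x: Tensor, size: Tuple[int, int]) -> Tensor:
        # size must be annotated for torch.jit.script; x would default
        # to Tensor anyway, but the explicit annotation is clearer.
        H, W = size
        return x.reshape(x.shape[0], H, W, -1)

scripted = torch.jit.script(Annotated())
print(scripted(torch.randn(2, 12, 8), (3, 4)).shape)  # torch.Size([2, 3, 4, 8])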
@@ -104,9 +104,9 @@ class PatchEmbed(nn.Module):
         self.norm = nn.LayerNorm(in_chans)
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x : Tensor, size: Tuple[int, int]):
         H, W = size
-        dim = len(x.shape)
+        dim = x.dim()
         if dim == 3:
             B, HW, C = x.shape
             x = self.norm(x)
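`x.dim()` and `len(x.shape)` both return the tensor rank; the method form is the idiomatic Tensor API and reads more cleanly under TorchScript. Illustration only:

import torch

tokens = torch.randn(2, 196, 96)    # (B, HW, C) token layout
image = torch.randn(2, 96, 14, 14)  # (B, C, H, W) image layout
assert tokens.dim() == len(tokens.shape) == 3
assert image.dim() == 4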
@@ -140,7 +140,7 @@ class ChannelAttention(nn.Module):
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim)
 
-    def forward(self, x):
+    def forward(self, x : Tensor):
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
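The fused `qkv` projection produces one `(B, N, 3 * C)` tensor; the reshape/permute then splits it into per-head q, k and v. A shape walk-through with made-up sizes (illustration only; in DaViT's channel attention these heads go on to attend over the channel axis rather than over tokens):

import torch
from torch import nn

B, N, C, num_heads = 2, 196, 96, 3
qkv_proj = nn.Linear(C, C * 3)

qkv = qkv_proj(torch.randn(B, N, C))                  # (B, N, 3*C)
qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads)
qkv = qkv.permute(2, 0, 3, 1, 4)                      # (3, B, heads, N, head_dim)
q, k, v = qkv.unbind(0)
print(q.shape)  # torch.Size([2, 3, 196, 32])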
@@ -178,7 +178,7 @@ class ChannelBlock(nn.Module):
             act_layer=act_layer)
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x : Tensor, size: Tuple[int, int]):
         x = self.cpe[0](x, size)
         cur = self.norm1(x)
         cur = self.attn(cur)
@@ -190,7 +190,7 @@ class ChannelBlock(nn.Module):
         return x, size
 
 
-def window_partition(x, window_size: int):
+def window_partition(x : Tensor, window_size: int):
     """
     Args:
         x: (B, H, W, C)
@@ -204,7 +204,7 @@ def window_partition(x, window_size: int):
     return windows
 
 
-def window_reverse(windows, window_size: int, H: int, W: int):
+def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
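`window_partition` and `window_reverse` are exact inverses. The sketch below reproduces the standard Swin-style implementation matching the documented shapes (an illustration, not a verbatim copy of davit.py):

import torch

def window_partition(x: torch.Tensor, window_size: int):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int):
    B = windows.shape[0] // (H * W // window_size // window_size)
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

x = torch.randn(2, 14, 14, 96)
assert torch.equal(window_reverse(window_partition(x, 7), 7, 14, 14), x)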
@@ -244,7 +244,7 @@ class WindowAttention(nn.Module):
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x):
+    def forward(self, x : Tensor):
         B_, N, C = x.shape
         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
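WindowAttention splits qkv exactly as ChannelAttention does; the difference is which axis the attention matrix covers. With N window tokens and head_dim d, spatial window attention forms N x N scores while channel attention forms d x d scores. A dimensional sketch (the channel-wise formula here is a hypothetical simplification, not the exact davit.py computation):

import torch

B_, heads, N, head_dim = 8, 3, 49, 32
q = torch.randn(B_, heads, N, head_dim)
k = torch.randn(B_, heads, N, head_dim)

spatial_scores = (q * head_dim ** -0.5) @ k.transpose(-2, -1)  # (8, 3, 49, 49)
channel_scores = k.transpose(-2, -1) @ (q * N ** -0.5)         # (8, 3, 32, 32)
print(spatial_scores.shape, channel_scores.shape)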
@@ -303,7 +303,7 @@ class SpatialBlock(nn.Module):
             act_layer=act_layer)
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x : Tensor, size: Tuple[int, int]):
         H, W = size
         B, L, C = x.shape
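Before partitioning, spatial blocks typically pad H and W up to multiples of the window size; the body of this forward is elided by the diff, so the following is a sketch of that usual Swin-style step, not the literal code:

import torch
import torch.nn.functional as F

window_size = 7
x = torch.randn(2, 13, 15, 96)  # (B, H, W, C), not window-aligned
pad_b = (window_size - x.shape[1] % window_size) % window_size
pad_r = (window_size - x.shape[2] % window_size) % window_size
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))  # pad order: C, then W, then H
print(x.shape)  # torch.Size([2, 14, 21, 96])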
@@ -341,13 +341,14 @@ class SpatialBlock(nn.Module):
         x = self.cpe[1](x, size)
         if self.ffn:
             x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x, size
 
 
 class DaViT(nn.Module):
     r""" DaViT
         A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645
+        Supports arbitrary input sizes and pyramid feature extraction
 
     Args:
         in_chans (int): Number of input image channels. Default: 3
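A usage sketch for the docstring claim above, assuming the davit_tiny config this PR registers with timm (model name hedged; check the registry of your timm build):

import timm
import torch

model = timm.create_model('davit_tiny', pretrained=False)
print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])
# "arbitrary input sizes": a non-default resolution should also run
print(model(torch.randn(1, 3, 256, 192)).shape)  # torch.Size([1, 1000])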
@@ -526,7 +527,7 @@ class DaViTFeatures(DaViT):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (0, 1, 2, 3)))
+        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
 
     def forward(self, x) -> List[Tensor]:
         return self.forward_pyramid_features(x)
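With the 'out_inices' typo fixed, out_indices actually reaches FeatureInfo and selects which stages forward_pyramid_features returns. A hedged usage sketch via timm's features_only path, assuming DaViT is wired into it as this class suggests:

import timm
import torch

backbone = timm.create_model('davit_tiny', features_only=True, out_indices=(0, 1, 2, 3))
for f in backbone(torch.randn(1, 3, 224, 224)):
    print(f.shape)  # one (B, C, H, W) map per selected stage, strides 4/8/16/32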
