diff --git a/timm/models/davit.py b/timm/models/davit.py
index be83e53d..273ebcd0 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -52,7 +52,7 @@ class ConvPosEnc(nn.Module):
         self.norm = nn.LayerNorm(dim)
         self.activation = nn.GELU() if act else nn.Identity()
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x: Tensor, size: Tuple[int, int]):
         B, N, C = x.shape
         H, W = size
         assert N == H * W
@@ -104,9 +104,9 @@ class PatchEmbed(nn.Module):
         self.norm = nn.LayerNorm(in_chans)
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x: Tensor, size: Tuple[int, int]):
         H, W = size
-        dim = len(x.shape)
+        dim = x.dim()
         if dim == 3:
             B, HW, C = x.shape
             x = self.norm(x)
@@ -140,7 +140,7 @@ class ChannelAttention(nn.Module):
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim)
 
-    def forward(self, x):
+    def forward(self, x: Tensor):
         B, N, C = x.shape
 
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
@@ -178,7 +178,7 @@ class ChannelBlock(nn.Module):
             act_layer=act_layer)
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x: Tensor, size: Tuple[int, int]):
         x = self.cpe[0](x, size)
         cur = self.norm1(x)
         cur = self.attn(cur)
@@ -190,7 +190,7 @@
         return x, size
 
 
-def window_partition(x, window_size: int):
+def window_partition(x: Tensor, window_size: int):
     """
    Args:
        x: (B, H, W, C)
@@ -204,7 +204,7 @@
    return windows
 
 
-def window_reverse(windows, window_size: int, H: int, W: int):
+def window_reverse(windows: Tensor, window_size: int, H: int, W: int):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
@@ -244,7 +244,7 @@
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x):
+    def forward(self, x: Tensor):
         B_, N, C = x.shape
 
         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
@@ -303,7 +303,7 @@ class SpatialBlock(nn.Module):
             act_layer=act_layer)
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x: Tensor, size: Tuple[int, int]):
         H, W = size
         B, L, C = x.shape
@@ -341,13 +341,14 @@ class SpatialBlock(nn.Module):
         x = self.cpe[1](x, size)
         if self.ffn:
             x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x, size
 
 
 class DaViT(nn.Module):
     r""" DaViT
         A PyTorch implementation of `DaViT: Dual Attention Vision Transformers`
             https://arxiv.org/abs/2204.03645
+        Supports arbitrary input sizes and pyramid feature extraction
 
     Args:
         in_chans (int): Number of input image channels. Default: 3
@@ -526,7 +527,7 @@ class DaViTFeatures(DaViT):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (0, 1, 2, 3)))
+        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
 
     def forward(self, x) -> List[Tensor]:
         return self.forward_pyramid_features(x)
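
The `window_partition` / `window_reverse` pair annotated above are exact inverses, which is what makes the windowed attention in `SpatialBlock` shape-safe. A minimal round-trip sketch, assuming the two helpers remain importable from `timm.models.davit` with the signatures shown in this diff:

```python
import torch
from timm.models.davit import window_partition, window_reverse  # import path assumed from this diff

x = torch.randn(2, 8, 8, 32)                 # (B, H, W, C); H and W must be divisible by window_size
windows = window_partition(x, 4)             # -> (num_windows * B, 4, 4, C), here (8, 4, 4, 32)
restored = window_reverse(windows, 4, 8, 8)  # inverse op: back to (B, H, W, C)
assert torch.equal(x, restored)              # partition followed by reverse is lossless
```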