From c8559878fba6015629834fd22717080a20fff57a Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 6 Dec 2022 19:44:22 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 444f21f3..17898160 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -113,7 +113,7 @@ class PatchEmbed(nn.Module): padding=to_2tuple(pad)) self.norm = nn.LayerNorm(in_chans) - def forward(self, x, size): + def forward(self, x, size: Tuple[int, int]): H, W = size dim = len(x.shape) if dim == 3: @@ -186,7 +186,7 @@ class ChannelBlock(nn.Module): hidden_features=mlp_hidden_dim, act_layer=act_layer) - def forward(self, x, size): + def forward(self, x, size: Tuple[int, int]): x = self.cpe[0](x, size) cur = self.norm1(x) cur = self.attn(cur) @@ -310,7 +310,7 @@ class SpatialBlock(nn.Module): hidden_features=mlp_hidden_dim, act_layer=act_layer) - def forward(self, x, size): + def forward(self, x, size: Tuple[int, int]): H, W = size B, L, C = x.shape assert L == H * W, "input feature has wrong size"