From fb90cb1503798fcd5cae387e60dd3343e3c1cd71 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 10 Dec 2022 20:43:53 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index edf6edbc..6e6fb994 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -475,11 +475,11 @@ class SpatialBlock(nn.Module): act_layer=act_layer) - def forward(self, x : Tensor, size: Tuple[int, int]): + def forward(self, x : Tensor): B, C, H, W = x.shape - shortcut = self.cpe1(x, size).flatten(2).transpose(1, 2) + shortcut = self.cpe1(x).flatten(2).transpose(1, 2) x = self.norm1(shortcut) x = x.view(B, H, W, C) @@ -508,13 +508,13 @@ class SpatialBlock(nn.Module): x = x.view(B, H * W, C) x = shortcut + self.drop_path(x) - x = self.cpe2(x, size) + x = self.cpe2(x) if self.ffn: x = x + self.drop_path(self.mlp(self.norm2(x))) x = x.transpose(1, 2).view(B, C, H, W) - return x, size + return x class SpatialBlockOld(nn.Module): r""" Windows Block.