From 7dd25eb4027f24b29a1b88c6f7f74743e47aa2d5 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 10 Dec 2022 20:54:39 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 04b72ace..2e3e33ba 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -102,7 +102,7 @@ class ConvPosEnc(nn.Module): B, C, H, W = x.shape #feat = x.transpose(1, 2).view(B, C, H, W) - feat = self.proj(feat) + feat = self.proj(x) if self.normtype == 'batch': feat = self.norm(feat).flatten(2).transpose(1, 2) elif self.normtype == 'layer': @@ -216,7 +216,7 @@ class PatchEmbed(nn.Module): def forward(self, x : Tensor): B, C, H, W = x.shape if self.norm.normalized_shape[0] == self.in_chans: - x = self.norm(x.flatten(2).transpose(1, 2)).transpose(1, 2).view(B, C, H, W) + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) if W % self.patch_size[1] != 0: x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))