Update davit.py

pull/1630/head
Fredo Guan 2 years ago
parent 55cfdd7b0d
commit 7dd25eb402

@ -102,7 +102,7 @@ class ConvPosEnc(nn.Module):
B, C, H, W = x.shape
#feat = x.transpose(1, 2).view(B, C, H, W)
feat = self.proj(feat)
feat = self.proj(x)
if self.normtype == 'batch':
feat = self.norm(feat).flatten(2).transpose(1, 2)
elif self.normtype == 'layer':
@ -216,7 +216,7 @@ class PatchEmbed(nn.Module):
def forward(self, x : Tensor):
B, C, H, W = x.shape
if self.norm.normalized_shape[0] == self.in_chans:
x = self.norm(x.flatten(2).transpose(1, 2)).transpose(1, 2).view(B, C, H, W)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))

Loading…
Cancel
Save