Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 385e371204
commit 347386580f

@ -92,7 +92,7 @@ class PatchEmbed(nn.Module):
stride=patch_size,
padding=(3, 3))
self.norm = nn.LayerNorm(embed_dim)
self.norm_after = True
self.norm_after = False
if patch_size[0] == 2:
kernel = 3 if overlapped else 2
pad = 1 if overlapped else 0
@ -103,7 +103,7 @@ class PatchEmbed(nn.Module):
stride=patch_size,
padding=to_2tuple(pad))
self.norm = nn.LayerNorm(in_chans)
self.norm_after = False
self.norm_after = True
def forward(self, x : Tensor, size: Tuple[int, int]):
@ -111,15 +111,14 @@ class PatchEmbed(nn.Module):
#dim = x.dim()
#if dim == 3:
in_shape = x.shape
B = in_shape[0]
C = in_shape[-1]
if self.norm_after == False:
B, HW, C = in_shape
x = self.norm(x)
x = x.reshape(B,
H,
W,
C).permute(0, 3, 1, 2).contiguous()
x = x.reshape(B,
H,
W,
C).permute(0, 3, 1, 2).contiguous()
B, C, H, W = x.shape
if W % self.patch_size[1] != 0:

Loading…
Cancel
Save