@@ -92,7 +92,7 @@ class PatchEmbed(nn.Module):
                 stride=patch_size,
                 padding=(3, 3))
             self.norm = nn.LayerNorm(embed_dim)
-            self.norm_after = False
+            self.norm_after = True
         if patch_size[0] == 2:
             kernel = 3 if overlapped else 2
             pad = 1 if overlapped else 0
@@ -103,7 +103,7 @@ class PatchEmbed(nn.Module):
                 stride=patch_size,
                 padding=to_2tuple(pad))
             self.norm = nn.LayerNorm(in_chans)
-            self.norm_after = True
+            self.norm_after = False
 
     def forward(self, x : Tensor, size: Tuple[int, int]):
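Taken together, the two constructor hunks pin the normalization placement to each branch: the stride-4 stem keeps `nn.LayerNorm(embed_dim)` and now flags it as post-projection (`norm_after = True`), while the stride-2 downsampler keeps `nn.LayerNorm(in_chans)` and flags it as pre-projection (`norm_after = False`). A minimal standalone sketch of that invariant, assuming the branch structure around the quoted lines (the helper name is illustrative, not from the patch):

```python
import torch.nn as nn

# Hypothetical helper mirroring the two branches the hunks touch;
# only the LayerNorm arguments and norm_after values come from the diff.
def build_patch_embed_norm(patch_size: int, in_chans: int, embed_dim: int):
    if patch_size == 4:                  # stem: 7x7 conv, stride 4
        norm = nn.LayerNorm(embed_dim)   # normalizes projected channels...
        norm_after = True                # ...so it must run after proj
    elif patch_size == 2:                # stage-to-stage downsampler
        norm = nn.LayerNorm(in_chans)    # normalizes incoming token channels...
        norm_after = False               # ...so it must run before proj
    else:
        raise ValueError(f"unsupported patch_size: {patch_size}")
    return norm, norm_after
```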
@@ -111,10 +111,9 @@ class PatchEmbed(nn.Module):
-        dim = x.dim()
-        if dim == 3:
-            B, HW, C = x.shape
+        #dim = x.dim()
+        #if dim == 3:
+        in_shape = x.shape
+        B = in_shape[0]
+        C = in_shape[-1]
+
+        if self.norm_after == False:
+            B, HW, C = in_shape
             x = self.norm(x)
             x = x.reshape(B,
                           H,
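The forward hunk is cut off above mid-statement. For orientation, here is a hedged end-to-end sketch of how the patched class plausibly fits together, with the runtime `x.dim()` check replaced by the static `norm_after` flag; everything past the visible `x = x.reshape(B, H,` line, plus the padding and projection logic, is an assumption modeled on the usual size-agnostic PatchEmbed pattern rather than text from the diff:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Tuple


class PatchEmbedSketch(nn.Module):
    """Hypothetical reconstruction; only lines visible in the hunks are confirmed."""

    def __init__(self, patch_size: int = 4, in_chans: int = 3,
                 embed_dim: int = 96, overlapped: bool = False):
        super().__init__()
        self.patch_size = (patch_size, patch_size)
        if patch_size == 4:
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=(7, 7),
                                  stride=self.patch_size, padding=(3, 3))
            self.norm = nn.LayerNorm(embed_dim)
            self.norm_after = True        # stem: normalize after projection
        if patch_size == 2:
            kernel = 3 if overlapped else 2
            pad = 1 if overlapped else 0
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel,
                                  stride=self.patch_size, padding=pad)
            self.norm = nn.LayerNorm(in_chans)
            self.norm_after = False       # downsampler: normalize before projection

    def forward(self, x: Tensor, size: Tuple[int, int]):
        H, W = size
        if not self.norm_after:
            # Token input (B, H*W, C): normalize, then fold back into an image.
            B, HW, C = x.shape
            x = self.norm(x)
            x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
        # Pad so H and W divide evenly by the stride (assumed, not in the hunks).
        B, C, H, W = x.shape
        if W % self.patch_size[1] != 0:
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        if H % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
        x = self.proj(x)
        newsize = (x.size(2), x.size(3))
        x = x.flatten(2).transpose(1, 2)  # (B, H'*W', embed_dim)
        if self.norm_after:
            x = self.norm(x)
        return x, newsize
```

Gating on a constant attribute instead of `x.dim()` keeps the branch static, which is friendlier to tracing and ONNX export; that appears to be the motivation for commenting out the `dim` check. A quick shape check under the same assumptions:

```python
stem = PatchEmbedSketch(patch_size=4)
tokens, hw = stem(torch.randn(2, 3, 224, 224), size=(224, 224))
# tokens: (2, 3136, 96), hw: (56, 56)

down = PatchEmbedSketch(patch_size=2, in_chans=96, embed_dim=192)
tokens, hw = down(tokens, size=hw)
# tokens: (2, 784, 192), hw: (28, 28)
```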