|
|
|
@ -92,7 +92,7 @@ class PatchEmbed(nn.Module):
|
|
|
|
|
stride=patch_size,
|
|
|
|
|
padding=(3, 3))
|
|
|
|
|
self.norm = nn.LayerNorm(embed_dim)
|
|
|
|
|
self.norm_after = False
|
|
|
|
|
self.norm_after = True
|
|
|
|
|
if patch_size[0] == 2:
|
|
|
|
|
kernel = 3 if overlapped else 2
|
|
|
|
|
pad = 1 if overlapped else 0
|
|
|
|
@ -103,7 +103,7 @@ class PatchEmbed(nn.Module):
|
|
|
|
|
stride=patch_size,
|
|
|
|
|
padding=to_2tuple(pad))
|
|
|
|
|
self.norm = nn.LayerNorm(in_chans)
|
|
|
|
|
self.norm_after = True
|
|
|
|
|
self.norm_after = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
|