@@ -89,7 +89,6 @@ class PatchEmbed(nn.Module):
                 stride=patch_size,
                 padding=(3, 3))
             self.norm = nn.LayerNorm(embed_dim)
-            self.norm_after = True
         if patch_size[0] == 2:
             kernel = 3 if overlapped else 2
             pad = 1 if overlapped else 0
@@ -100,17 +99,13 @@ class PatchEmbed(nn.Module):
                 stride=patch_size,
                 padding=to_2tuple(pad))
             self.norm = nn.LayerNorm(in_chans)
-            self.norm_after = False
 
-
+    @register_notrace_function # reason: dim in control sequence
     def forward(self, x : Tensor, size: Tuple[int, int]):
         H, W = size
-        in_shape = x.shape
-
-
-        # norm_after variable used as a workaround to original len(x.shape) == 3
-        if self.norm_after == False:
-            B, HW, C = in_shape
+        dim = len(x.shape)
+        if dim == 3:
+            B, HW, C = x.shape
             x = self.norm(x)
             x = x.reshape(B,
                           H,
@@ -118,14 +113,15 @@ class PatchEmbed(nn.Module):
                           C).permute(0, 3, 1, 2).contiguous()
 
         B, C, H, W = x.shape
-
-        x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
-        x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+        if W % self.patch_size[1] != 0:
+            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+        if H % self.patch_size[0] != 0:
+            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
 
         x = self.proj(x)
         newsize = (x.size(2), x.size(3))
         x = x.flatten(2).transpose(1, 2)
-        if self.norm_after == True:
+        if dim == 4:
             x = self.norm(x)
         return x, newsize
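
Note on the padding change in the last hunk: when a dimension is already a multiple of the patch size, `patch_size - W % patch_size` evaluates to the full patch size, so the unguarded F.pad would grow the input by a whole extra patch. A minimal standalone sketch (not part of the patch) of the difference:

import torch
import torch.nn.functional as F

patch_w = 4
x = torch.randn(1, 3, 8, 8)                # W = 8 is already divisible by patch_w

pad_w = patch_w - x.shape[-1] % patch_w    # 4 - 0 == 4, not 0
print(F.pad(x, (0, pad_w)).shape)          # torch.Size([1, 3, 8, 12]) -> over-padded

if x.shape[-1] % patch_w != 0:             # guarded version from the diff: no-op here
    x = F.pad(x, (0, patch_w - x.shape[-1] % patch_w))
print(x.shape)                             # torch.Size([1, 3, 8, 8])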
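
For reference, forward accepts either layout; the `dim == 3` branch converts a token sequence back to NCHW before the conv. A small sketch (standalone, not part of the patch) of that reshape:

import torch

B, C, H, W = 2, 96, 14, 14
tokens = torch.randn(B, H * W, C)                                # dim == 3: (B, HW, C) tokens
x = tokens.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()  # -> (B, C, H, W)
assert x.shape == (B, C, H, W)                                   # ready for self.proj (Conv2d)

image = torch.randn(B, 3, 224, 224)                              # dim == 4: raw image for the stem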
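
The `@register_notrace_function` line ties to the `dim = len(x.shape)` control flow: torch.fx cannot symbolically trace a branch that depends on a runtime shape, which is why the previous revision carried the `norm_after` workaround. A hedged illustration with a toy module (not timm's registration mechanism):

import torch
import torch.fx


class ShapeBranch(torch.nn.Module):
    def forward(self, x):
        if len(x.shape) == 3:       # branch on runtime rank, as in PatchEmbed.forward
            x = x.unsqueeze(1)
        return x


try:
    torch.fx.symbolic_trace(ShapeBranch())
except Exception as e:              # fx rejects len()/bool() on a traced Proxy
    print(type(e).__name__, e)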