|
|
@ -108,8 +108,6 @@ class PatchEmbed(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
H, W = size
|
|
|
|
H, W = size
|
|
|
|
#dim = x.dim()
|
|
|
|
|
|
|
|
#if dim == 3:
|
|
|
|
|
|
|
|
in_shape = x.shape
|
|
|
|
in_shape = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
if self.norm_after == False:
|
|
|
|
if self.norm_after == False:
|
|
|
@ -128,7 +126,6 @@ class PatchEmbed(nn.Module):
|
|
|
|
x = self.proj(x)
|
|
|
|
x = self.proj(x)
|
|
|
|
newsize = (x.size(2), x.size(3))
|
|
|
|
newsize = (x.size(2), x.size(3))
|
|
|
|
x = x.flatten(2).transpose(1, 2)
|
|
|
|
x = x.flatten(2).transpose(1, 2)
|
|
|
|
#if dim == 4:
|
|
|
|
|
|
|
|
if self.norm_after == True:
|
|
|
|
if self.norm_after == True:
|
|
|
|
x = self.norm(x)
|
|
|
|
x = self.norm(x)
|
|
|
|
return x, newsize
|
|
|
|
return x, newsize
|
|
|
|