Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 30503a383a
commit 8bb6b3687b

@ -108,8 +108,6 @@ class PatchEmbed(nn.Module):
def forward(self, x : Tensor, size: Tuple[int, int]):
H, W = size
#dim = x.dim()
#if dim == 3:
in_shape = x.shape
if self.norm_after == False:
@ -128,7 +126,6 @@ class PatchEmbed(nn.Module):
x = self.proj(x)
newsize = (x.size(2), x.size(3))
x = x.flatten(2).transpose(1, 2)
#if dim == 4:
if self.norm_after == True:
x = self.norm(x)
return x, newsize

Loading…
Cancel
Save