Update metaformers.py

pull/1647/head
Fredo Guan 2 years ago
parent ddfdb543dc
commit 8568bc7b6a

@ -236,9 +236,13 @@ class Downsampling(nn.Module):
self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()
def forward(self, x):
print(x.shape)
x = self.pre_norm(x)
print(x.shape)
x = self.conv(x)
print(x.shape)
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(x.shape)
return x
class Scale(nn.Module):

Loading…
Cancel
Save