@ -236,9 +236,13 @@ class Downsampling(nn.Module):
self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()
def forward(self, x):
print(x.shape)
x = self.pre_norm(x)
x = self.conv(x)
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
class Scale(nn.Module):