Update metaformers.py

pull/1647/head
Fredo Guan 2 years ago
parent 8568bc7b6a
commit 464ae0fe50

@ -197,7 +197,7 @@ default_cfgs = {
}
cfgs_v2 = generate_default_cfgs(default_cfgs)
'''
class Downsampling(nn.Module):
"""
Downsampling implemented by a layer of convolution.
@ -213,13 +213,13 @@ class Downsampling(nn.Module):
self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()
def forward(self, x):
x = self.pre_norm(x)
if self.pre_permute:
# if take [B, H, W, C] as input, permute it to [B, C, H, W]
x = x.permute(0, 3, 1, 2)
x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.conv(x)
x = x.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C]
x = self.post_norm(x)
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
'''
class Downsampling(nn.Module):
@ -244,7 +244,7 @@ class Downsampling(nn.Module):
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(x.shape)
return x
'''
class Scale(nn.Module):
"""
Scale vector by element multiplications.

Loading…
Cancel
Save