Update metaformers.py

pull/1647/head
Fredo Guan 3 years ago
parent 8568bc7b6a
commit 464ae0fe50

@ -197,7 +197,7 @@ default_cfgs = {
} }
cfgs_v2 = generate_default_cfgs(default_cfgs) cfgs_v2 = generate_default_cfgs(default_cfgs)
'''
class Downsampling(nn.Module): class Downsampling(nn.Module):
""" """
Downsampling implemented by a layer of convolution. Downsampling implemented by a layer of convolution.
@ -213,13 +213,13 @@ class Downsampling(nn.Module):
self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()
def forward(self, x): def forward(self, x):
x = self.pre_norm(x)
if self.pre_permute: if self.pre_permute:
# if take [B, H, W, C] as input, permute it to [B, C, H, W] # if take [B, H, W, C] as input, permute it to [B, C, H, W]
x = x.permute(0, 3, 1, 2) x = x.permute(0, 3, 1, 2)
x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.conv(x) x = self.conv(x)
x = x.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C] x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.post_norm(x)
return x return x
''' '''
class Downsampling(nn.Module): class Downsampling(nn.Module):
@ -244,7 +244,7 @@ class Downsampling(nn.Module):
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(x.shape) print(x.shape)
return x return x
'''
class Scale(nn.Module): class Scale(nn.Module):
""" """
Scale vector by element multiplications. Scale vector by element multiplications.

Loading…
Cancel
Save