diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py
index fe5622c6..9bb85c44 100644
--- a/timm/models/metaformers.py
+++ b/timm/models/metaformers.py
@@ -197,7 +197,7 @@ default_cfgs = {
 }
 
 cfgs_v2 = generate_default_cfgs(default_cfgs)
-
+'''
 class Downsampling(nn.Module):
     """
     Downsampling implemented by a layer of convolution.
@@ -221,7 +221,26 @@ class Downsampling(nn.Module):
         x = x.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C]
         x = self.post_norm(x)
         return x
+'''
+class Downsampling(nn.Module):
+    """
+    Downsampling implemented by a layer of convolution.
+    """
+    def __init__(self, in_channels, out_channels,
+                 kernel_size, stride=1, padding=0,
+                 pre_norm=None, post_norm=None):
+        super().__init__()
+        self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
+        # NOTE: norm layers run channels-last; conv runs channels-first (see forward)
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
+            stride=stride, padding=padding)
+        self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()
+    def forward(self, x):
+        x = self.pre_norm(x)
+        x = self.conv(x)
+        x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        return x
 
 
 class Scale(nn.Module):
     """
@@ -544,6 +563,8 @@ class MetaFormerBlock(nn.Module):
             if res_scale_init_value else nn.Identity()
 
     def forward(self, x):
+        # [B, C, H, W] -> [B, H, W, C]; permute (not view) preserves element order
+        x = x.permute(0, 2, 3, 1)
         x = self.res_scale1(x) + \
             self.layer_scale1(
                 self.drop_path1(
@@ -556,6 +577,7 @@ class MetaFormerBlock(nn.Module):
                     self.mlp(self.norm2(x))
                 )
             )
+        x = x.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
         return x
 
 