diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 30cba4cc..f5e12c5a 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -1,11 +1,10 @@ - - """ +Poolformer from MetaFormer is Actually What You Need for Vision https://arxiv.org/abs/2111.11418 MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2, -ConvFormer and CAFormer. +ConvFormer, and CAFormer as per https://arxiv.org/abs/2210.13452 -original copyright below +Adapted from https://github.com/sail-sg/metaformer, original copyright below """ # Copyright 2022 Garena Online Private Limited @@ -155,6 +154,7 @@ class Attention(nn.Module): class RandomMixing(nn.Module): def __init__(self, num_tokens=196, **kwargs): super().__init__() + # FIXME no grad breaks tests self.random_matrix = nn.parameter.Parameter( data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), requires_grad=False) @@ -367,8 +367,6 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() def forward(self, x): - #B, C, H, W = x.shape - #x = x.view(B, H, W, C) x = x.permute(0, 2, 3, 1) x = self.res_scale1(x) + \ self.layer_scale1( @@ -382,7 +380,6 @@ class MetaFormerBlock(nn.Module): self.mlp(self.norm2(x)) ) ) - #x = x.view(B, C, H, W) x = x.permute(0, 3, 1, 2) return x @@ -396,7 +393,6 @@ class MetaFormer(nn.Module): num_classes (int): Number of classes for classification head. Default: 1000. depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2]. dims (int): Feature dimension at each stage. Default: [64, 128, 320, 512]. - downsample_layers: (list or tuple): Downsampling layers before each stage. token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity. mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp. norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormGeneral, eps=1e-6, bias=False). @@ -415,7 +411,6 @@ class MetaFormer(nn.Module): num_classes=1000, depths=[2, 2, 6, 2], dims=[64, 128, 320, 512], - #downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES, downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), token_mixers=nn.Identity, mlps=Mlp, @@ -445,15 +440,7 @@ class MetaFormer(nn.Module): dims = [dims] self.num_stages = len(depths) - ''' - if not isinstance(downsample_layers, (list, tuple)): - downsample_layers = [downsample_layers] * self.num_stages - down_dims = [in_chans] + dims - - downsample_layers = nn.ModuleList( - [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(self.num_stages)] - ) - ''' + if not isinstance(token_mixers, (list, tuple)): token_mixers = [token_mixers] * self.num_stages