Update metaformers.py

pull/1647/head
Fredo Guan 2 years ago
parent 5d9cb3b943
commit 5a19034a99

@ -1,11 +1,10 @@
"""
Poolformer from MetaFormer is Actually What You Need for Vision https://arxiv.org/abs/2111.11418
MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2,
ConvFormer and CAFormer.
ConvFormer, and CAFormer as per https://arxiv.org/abs/2210.13452
original copyright below
Adapted from https://github.com/sail-sg/metaformer, original copyright below
"""
# Copyright 2022 Garena Online Private Limited
@ -155,6 +154,7 @@ class Attention(nn.Module):
class RandomMixing(nn.Module):
def __init__(self, num_tokens=196, **kwargs):
super().__init__()
# FIXME no grad breaks tests
self.random_matrix = nn.parameter.Parameter(
data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1),
requires_grad=False)
@ -367,8 +367,6 @@ class MetaFormerBlock(nn.Module):
if res_scale_init_value else nn.Identity()
def forward(self, x):
#B, C, H, W = x.shape
#x = x.view(B, H, W, C)
x = x.permute(0, 2, 3, 1)
x = self.res_scale1(x) + \
self.layer_scale1(
@ -382,7 +380,6 @@ class MetaFormerBlock(nn.Module):
self.mlp(self.norm2(x))
)
)
#x = x.view(B, C, H, W)
x = x.permute(0, 3, 1, 2)
return x
@ -396,7 +393,6 @@ class MetaFormer(nn.Module):
num_classes (int): Number of classes for classification head. Default: 1000.
depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2].
dims (int): Feature dimension at each stage. Default: [64, 128, 320, 512].
downsample_layers: (list or tuple): Downsampling layers before each stage.
token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity.
mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp.
norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormGeneral, eps=1e-6, bias=False).
@ -415,7 +411,6 @@ class MetaFormer(nn.Module):
num_classes=1000,
depths=[2, 2, 6, 2],
dims=[64, 128, 320, 512],
#downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES,
downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6),
token_mixers=nn.Identity,
mlps=Mlp,
@ -445,15 +440,7 @@ class MetaFormer(nn.Module):
dims = [dims]
self.num_stages = len(depths)
'''
if not isinstance(downsample_layers, (list, tuple)):
downsample_layers = [downsample_layers] * self.num_stages
down_dims = [in_chans] + dims
downsample_layers = nn.ModuleList(
[downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(self.num_stages)]
)
'''
if not isinstance(token_mixers, (list, tuple)):
token_mixers = [token_mixers] * self.num_stages

Loading…
Cancel
Save