@@ -1,11 +1,10 @@
 """
 Poolformer from MetaFormer is Actually What You Need for Vision https://arxiv.org/abs/2111.11418

 MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2,
-ConvFormer and CAFormer.
+ConvFormer, and CAFormer as per https://arxiv.org/abs/2210.13452

-original copyright below
+Adapted from https://github.com/sail-sg/metaformer, original copyright below
 """

 # Copyright 2022 Garena Online Private Limited
@@ -155,6 +154,7 @@ class Attention(nn.Module):
 class RandomMixing(nn.Module):
     def __init__(self, num_tokens=196, **kwargs):
         super().__init__()
+        # FIXME no grad breaks tests
         self.random_matrix = nn.parameter.Parameter(
             data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1),
             requires_grad=False)
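
For context on the hunk above: RandomMixing's forward pass (outside this hunk) applies the frozen matrix as a fixed, softmax-weighted mix over all tokens. A minimal runnable sketch; the forward body is assumed from the sail-sg reference implementation, not taken from this diff:

    import torch
    import torch.nn as nn

    class RandomMixing(nn.Module):
        def __init__(self, num_tokens=196, **kwargs):
            super().__init__()
            # Frozen via requires_grad=False; per the FIXME above, wrapping
            # the init in torch.no_grad() instead broke tests.
            self.random_matrix = nn.parameter.Parameter(
                data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1),
                requires_grad=False)

        def forward(self, x):
            # Assumed from the reference implementation: mix over the H*W tokens.
            B, H, W, C = x.shape
            x = x.reshape(B, H * W, C)
            x = torch.einsum('mn, bnc -> bmc', self.random_matrix, x)
            return x.reshape(B, H, W, C)

    # e.g. RandomMixing(num_tokens=196)(torch.randn(1, 14, 14, 64)) keeps the shape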
@@ -367,8 +367,6 @@ class MetaFormerBlock(nn.Module):
             if res_scale_init_value else nn.Identity()

     def forward(self, x):
-        #B, C, H, W = x.shape
-        #x = x.view(B, H, W, C)
         x = x.permute(0, 2, 3, 1)
         x = self.res_scale1(x) + \
             self.layer_scale1(
@@ -382,7 +380,6 @@
                     self.mlp(self.norm2(x))
                 )
             )
-        #x = x.view(B, C, H, W)
         x = x.permute(0, 3, 1, 2)
         return x

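The two hunks above also pin down the block's layout contract: activations travel between blocks as B, C, H, W, while the norms, token mixer, and MLP inside the block run channels-last so LayerNorm and Linear act on the channel dimension. Note the deleted view lines would have been wrong anyway: view(B, H, W, C) on a contiguous B, C, H, W tensor reinterprets memory rather than transposing it, which is why permute is used. A small sketch of the contract, with illustrative shapes:

    import torch

    x = torch.randn(2, 64, 14, 14)   # B, C, H, W: layout between blocks
    x = x.permute(0, 2, 3, 1)        # B, H, W, C: channels-last inside the block,
                                     # so LayerNorm/Linear see C as the last dim
    # ... norm1 -> token_mixer -> scaled residual; norm2 -> mlp -> scaled residual ...
    x = x.permute(0, 3, 1, 2)        # back to B, C, H, W for the next block
    assert x.shape == (2, 64, 14, 14)
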
@@ -396,7 +393,6 @@ class MetaFormer(nn.Module):
         num_classes (int): Number of classes for classification head. Default: 1000.
         depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2].
         dims (int): Feature dimension at each stage. Default: [64, 128, 320, 512].
-        downsample_layers: (list or tuple): Downsampling layers before each stage.
         token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity.
         mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp.
         norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormGeneral, eps=1e-6, bias=False).
@@ -415,7 +411,6 @@ class MetaFormer(nn.Module):
             num_classes=1000,
             depths=[2, 2, 6, 2],
             dims=[64, 128, 320, 512],
-            #downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES,
             downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6),
             token_mixers=nn.Identity,
             mlps=Mlp,
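With the commented-out downsample_layers argument gone, only downsample_norm stays user-facing, which suggests the per-stage downsampling modules are built internally. A hypothetical sketch of such a stage-transition module; the class name, kernel size, stride, and stand-in norm are illustrative assumptions, not taken from this diff:

    import torch.nn as nn

    class Downsampling(nn.Module):
        # Hypothetical: pre-norm, then a strided conv halves the resolution.
        # The real model configures the norm via downsample_norm (LayerNormGeneral);
        # GroupNorm(1, ...) stands in here so the sketch runs on B, C, H, W input.
        def __init__(self, in_chs, out_chs, kernel_size=3, stride=2):
            super().__init__()
            self.norm = nn.GroupNorm(1, in_chs)
            self.conv = nn.Conv2d(
                in_chs, out_chs, kernel_size, stride=stride, padding=kernel_size // 2)

        def forward(self, x):  # B, C, H, W -> B, out_chs, H/2, W/2
            return self.conv(self.norm(x))
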
@@ -445,15 +440,7 @@
             dims = [dims]

         self.num_stages = len(depths)
-        '''
-        if not isinstance(downsample_layers, (list, tuple)):
-            downsample_layers = [downsample_layers] * self.num_stages
-        down_dims = [in_chans] + dims
-        downsample_layers = nn.ModuleList(
-            [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(self.num_stages)]
-        )
-        '''
         if not isinstance(token_mixers, (list, tuple)):
             token_mixers = [token_mixers] * self.num_stages
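
The isinstance check above is the file's broadcast idiom: a single value is expanded to one entry per stage, so callers may pass either one mixer class for the whole network or a per-stage list. A short sketch (Attention appears earlier in this file; SepConv is from the reference implementation):

    import torch.nn as nn

    num_stages = 4
    token_mixers = nn.Identity                       # one mixer for every stage
    if not isinstance(token_mixers, (list, tuple)):
        token_mixers = [token_mixers] * num_stages   # [nn.Identity] * 4

    # Per-stage form, e.g. CAFormer-style conv-early / attention-late:
    # token_mixers = [SepConv, SepConv, Attention, Attention]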