Update metaformers.py

pull/1647/head
Fredo Guan 3 years ago
parent 926d886527
commit 7f149f31d4

@ -521,14 +521,11 @@ class MetaFormerBlock(nn.Module):
norm_layer=nn.LayerNorm,
drop=0., drop_path=0.,
layer_scale_init_value=None,
res_scale_init_value=None,
downsample = nn.Identity()
res_scale_init_value=None
):
super().__init__()
self.downsample = downsample
self.norm1 = norm_layer(dim)
self.token_mixer = token_mixer(dim=dim, drop=drop)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -546,7 +543,6 @@ class MetaFormerBlock(nn.Module):
if res_scale_init_value else nn.Identity()
def forward(self, x):
x = self.downsample(x)
x = self.res_scale1(x) + \
self.layer_scale1(
self.drop_path1(
@ -653,18 +649,19 @@ class MetaFormer(nn.Module):
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
cur = 0
for i in range(num_stage):
stage = nn.Sequential(*[MetaFormerBlock(
stage = nn.Sequential(OrderedDict[
('downsample', downsample_layers[i]),
('blocks', nn.Sequential(*[MetaFormerBlock(
dim=dims[i],
token_mixer=token_mixers[i],
mlp=mlps[i],
norm_layer=norm_layers[i],
drop_path=dp_rates[cur + j],
layer_scale_init_value=layer_scale_init_values[i],
res_scale_init_value=res_scale_init_values[i],
downsample = downsample_layers[i]
) for j in range(depths[i])]
res_scale_init_value=res_scale_init_values[i]
) for j in range(depths[i])])
)]
)
stages.append(stage)
cur += depths[i]

Loading…
Cancel
Save