Update metaformers.py

pull/1647/head
Fredo Guan 2 years ago
parent ec202b4d16
commit d90ed530dc

@ -469,14 +469,19 @@ class MetaFormerBlock(nn.Module):
Implementation of one MetaFormer block. Implementation of one MetaFormer block.
""" """
def __init__(self, dim, def __init__(self, dim,
token_mixer=nn.Identity, mlp=Mlp, token_mixer=nn.Identity,
mlp=Mlp,
norm_layer=nn.LayerNorm, norm_layer=nn.LayerNorm,
drop=0., drop_path=0., drop=0., drop_path=0.,
layer_scale_init_value=None, res_scale_init_value=None layer_scale_init_value=None,
res_scale_init_value=None,
downsample = nn.Identity()
): ):
super().__init__() super().__init__()
self.downsample = nn.Identity()
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.token_mixer = token_mixer(dim=dim, drop=drop) self.token_mixer = token_mixer(dim=dim, drop=drop)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -494,6 +499,7 @@ class MetaFormerBlock(nn.Module):
if res_scale_init_value else nn.Identity() if res_scale_init_value else nn.Identity()
def forward(self, x): def forward(self, x):
x = self.downsample(x)
x = self.res_scale1(x) + \ x = self.res_scale1(x) + \
self.layer_scale1( self.layer_scale1(
self.drop_path1( self.drop_path1(
@ -600,18 +606,18 @@ class MetaFormer(nn.Module):
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
cur = 0 cur = 0
for i in range(num_stage): for i in range(num_stage):
stage = nn.Sequential( stage = nn.Sequential(*[MetaFormerBlock(
downsample_layers[i], dim=dims[i],
*[MetaFormerBlock( token_mixer=token_mixers[i],
dim=dims[i], mlp=mlps[i],
token_mixer=token_mixers[i], norm_layer=norm_layers[i],
mlp=mlps[i], drop_path=dp_rates[cur + j],
norm_layer=norm_layers[i], layer_scale_init_value=layer_scale_init_values[i],
drop_path=dp_rates[cur + j], res_scale_init_value=res_scale_init_values[i],
layer_scale_init_value=layer_scale_init_values[i], downsample = downsample_layers[i]
res_scale_init_value=res_scale_init_values[i],
) for j in range(depths[i])] ) for j in range(depths[i])]
) )
stages.append(stage) stages.append(stage)
cur += depths[i] cur += depths[i]
@ -649,6 +655,17 @@ class MetaFormer(nn.Module):
x = self.head(x) x = self.head(x)
return x return x
def checkpoint_filter_fn(state_dict, model):
    """Remap checkpoint keys from the ``downsample_layers`` naming scheme.

    Every occurrence of ``downsample_layers.<N>`` (``<N>`` being one or more
    digits) inside a key is rewritten to ``stages.<N>.downsample``. Keys
    without that pattern, and all values, are passed through unchanged.

    Args:
        state_dict: Mapping of parameter names to their values.
        model: Not used by this function.

    Returns:
        A new dict with the remapped keys; ``state_dict`` is not modified.
    """
    import re

    # The dot is escaped so that only a literal "downsample_layers.<N>" is
    # rewritten; an unescaped "." would match any character in that position.
    pattern = re.compile(r'downsample_layers\.([0-9]+)')
    return {
        pattern.sub(r'stages.\1.downsample', k): v
        for k, v in state_dict.items()
    }
def _create_metaformer(variant, pretrained=False, **kwargs): def _create_metaformer(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2)))) default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2))))
out_indices = kwargs.pop('out_indices', default_out_indices) out_indices = kwargs.pop('out_indices', default_out_indices)

Loading…
Cancel
Save