Update metaformers.py

pull/1647/head
Fredo Guan 2 years ago
parent 0bde1c1218
commit ec202b4d16

@@ -20,12 +20,20 @@ Some implementations are modified from timm (https://github.com/rwightman/pytorc
 from functools import partial
 import torch
 import torch.nn as nn
 
-from timm.layers import trunc_normal_, DropPath
-from ._registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import trunc_normal_, DropPath
+from timm.layers.helpers import to_2tuple
+
+from ._builder import build_model_with_cfg
+from ._features import FeatureInfo
+from ._features_fx import register_notrace_function
+from ._manipulate import checkpoint_seq
+from ._pretrained import generate_default_cfgs
+from ._registry import register_model
 
 __all__ = ['MetaFormer']
 
 def _cfg(url='', **kwargs):
     return {
         'url': url,
@@ -187,6 +195,7 @@ default_cfgs = {
         num_classes=21841),
 }
 
+cfgs_v2 = generate_default_cfgs(default_cfgs)
 
 class Downsampling(nn.Module):
     """
@@ -592,16 +601,17 @@ class MetaFormer(nn.Module):
         cur = 0
         for i in range(num_stage):
             stage = nn.Sequential(
-                *[MetaFormerBlock(dim=dims[i],
-                                  token_mixer=token_mixers[i],
-                                  mlp=mlps[i],
-                                  norm_layer=norm_layers[i],
-                                  drop_path=dp_rates[cur + j],
-                                  layer_scale_init_value=layer_scale_init_values[i],
-                                  res_scale_init_value=res_scale_init_values[i],
+                downsample_layers[i],
+                *[MetaFormerBlock(
+                    dim=dims[i],
+                    token_mixer=token_mixers[i],
+                    mlp=mlps[i],
+                    norm_layer=norm_layers[i],
+                    drop_path=dp_rates[cur + j],
+                    layer_scale_init_value=layer_scale_init_values[i],
+                    res_scale_init_value=res_scale_init_values[i],
                 ) for j in range(depths[i])]
             )
-            stages.append(downsample_layers[i])
             stages.append(stage)
             cur += depths[i]
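Note on the hunk above: the downsampling layer now sits as the first module inside each stage's nn.Sequential instead of being appended to stages on its own, so every entry in stages is one self-contained stage (downsample first, then that stage's MetaFormerBlocks). A minimal sketch of the resulting layout, with nn.Conv2d and nn.Identity standing in for the real Downsampling and MetaFormerBlock modules:

import torch
import torch.nn as nn

# one stage = downsample followed by its blocks, all in one Sequential
stage = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=2),  # stand-in for Downsampling
    nn.Identity(),  # stand-in for a MetaFormerBlock
    nn.Identity(),
)
x = torch.randn(1, 3, 224, 224)
print(stage(x).shape)  # torch.Size([1, 64, 56, 56])

This uniform stage layout is what lets the feature_cfg=dict(flatten_sequential=True, ...) call added further down map stage boundaries to feature taps.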
@@ -639,8 +649,20 @@ class MetaFormer(nn.Module):
         x = self.head(x)
         return x
 
 
+def _create_metaformer(variant, pretrained=False, **kwargs):
+    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2))))
+    out_indices = kwargs.pop('out_indices', default_out_indices)
+    model = build_model_with_cfg(
+        MetaFormer,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs)
+    return model
+
+'''
 @register_model
 def identityformer_s12(pretrained=False, **kwargs):
     model = MetaFormer(
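Note on _create_metaformer above: default_out_indices enumerates the depths tuple, so a four-stage model defaults to out_indices=(0, 1, 2, 3), and pretrained weights now flow through build_model_with_cfg with checkpoint_filter_fn instead of the manual load_state_dict in the entrypoint being commented out here. A hedged usage sketch, assuming this file ships inside timm so the @register_model entrypoints below are visible to the factory:

import timm

# plain classifier; pretrained loading goes through build_model_with_cfg
model = timm.create_model('identityformer_s12', pretrained=False)

# feature-extraction wrapper; out_indices selects stages, which works
# because flatten_sequential=True maps the Sequential stages to taps
features = timm.create_model('identityformer_s12', features_only=True, out_indices=(0, 1, 3))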
@@ -656,6 +678,17 @@ def identityformer_s12(pretrained=False, **kwargs):
         model.load_state_dict(state_dict)
     return model
+'''
+
+
+@register_model
+def identityformer_s12(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        depths=[2, 2, 6, 2],
+        dims=[64, 128, 320, 512],
+        token_mixers=nn.Identity,
+        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
+        **kwargs)
+    return _create_metaformer('identityformer_s12', pretrained=pretrained, **model_kwargs)
 
 @register_model
 def identityformer_s24(pretrained=False, **kwargs):
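Note on the new identityformer_s12 above: with token_mixers=nn.Identity the blocks perform no token mixing at all, so the normalization choice carries the config; normalized_dim=(1, 2, 3) norms jointly over the H, W, and C dims of channels-last tensors. A functional sketch of that normalization (layernorm_general is a hypothetical stand-in, ignoring any learned scale the real LayerNormGeneral class in this file may apply; bias=False in the config means no learned shift):

import torch

def layernorm_general(x, normalized_dim=(1, 2, 3), eps=1e-6):
    # normalize jointly over the given dims, no affine bias
    mean = x.mean(dim=normalized_dim, keepdim=True)
    var = x.var(dim=normalized_dim, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

x = torch.randn(2, 56, 56, 64)  # (B, H, W, C), channels-last
y = layernorm_general(x)
print(round(y.mean().item(), 3), round(y.std().item(), 3))  # ~0.0, ~1.0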
