From ec202b4d163e4942e62c27837c30d3b679c7526f Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 23:17:25 -0800 Subject: [PATCH] Update metaformers.py --- timm/models/metaformers.py | 57 ++++++++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 12 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 99b9b53f..ff8bc0d5 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -20,12 +20,20 @@ Some implementations are modified from timm (https://github.com/rwightman/pytorc from functools import partial import torch import torch.nn as nn -from timm.layers import trunc_normal_, DropPath -from ._registry import register_model from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import trunc_normal_, DropPath from timm.layers.helpers import to_2tuple +from ._builder import build_model_with_cfg +from ._features import FeatureInfo +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model + +__all__ = ['MetaFormer'] + def _cfg(url='', **kwargs): return { 'url': url, @@ -187,6 +195,7 @@ default_cfgs = { num_classes=21841), } +cfgs_v2 = generate_default_cfgs(default_cfgs) class Downsampling(nn.Module): """ @@ -592,16 +601,17 @@ class MetaFormer(nn.Module): cur = 0 for i in range(num_stage): stage = nn.Sequential( - *[MetaFormerBlock(dim=dims[i], - token_mixer=token_mixers[i], - mlp=mlps[i], - norm_layer=norm_layers[i], - drop_path=dp_rates[cur + j], - layer_scale_init_value=layer_scale_init_values[i], - res_scale_init_value=res_scale_init_values[i], + downsample_layers[i], + *[MetaFormerBlock( + dim=dims[i], + token_mixer=token_mixers[i], + mlp=mlps[i], + norm_layer=norm_layers[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_values[i], + res_scale_init_value=res_scale_init_values[i], ) for j in range(depths[i])] ) - stages.append(downsample_layers[i]) stages.append(stage) cur += depths[i] @@ -639,8 +649,20 @@ class MetaFormer(nn.Module): x = self.head(x) return x - - +def _create_metaformer(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + + model = build_model_with_cfg( + MetaFormer, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices = out_indices), + **kwargs) + + return model +''' @register_model def identityformer_s12(pretrained=False, **kwargs): model = MetaFormer( @@ -656,6 +678,17 @@ def identityformer_s12(pretrained=False, **kwargs): model.load_state_dict(state_dict) return model +''' + +@register_model +def identityformer_s12(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[2, 2, 6, 2], + dims=[64, 128, 320, 512], + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + return _create_metaformer('identityformer_s12', pretrained=pretrained, **model_kwargs) @register_model def identityformer_s24(pretrained=False, **kwargs):