|
|
|
@ -20,12 +20,20 @@ Some implementations are modified from timm (https://github.com/rwightman/pytorc
|
|
|
|
|
from functools import partial
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
from timm.layers import trunc_normal_, DropPath
|
|
|
|
|
from ._registry import register_model
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from timm.layers import trunc_normal_, DropPath
|
|
|
|
|
from timm.layers.helpers import to_2tuple
|
|
|
|
|
from ._builder import build_model_with_cfg
|
|
|
|
|
from ._features import FeatureInfo
|
|
|
|
|
from ._features_fx import register_notrace_function
|
|
|
|
|
from ._manipulate import checkpoint_seq
|
|
|
|
|
from ._pretrained import generate_default_cfgs
|
|
|
|
|
from ._registry import register_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['MetaFormer']
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs):
|
|
|
|
|
return {
|
|
|
|
|
'url': url,
|
|
|
|
@ -187,6 +195,7 @@ default_cfgs = {
|
|
|
|
|
num_classes=21841),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cfgs_v2 = generate_default_cfgs(default_cfgs)
|
|
|
|
|
|
|
|
|
|
class Downsampling(nn.Module):
|
|
|
|
|
"""
|
|
|
|
@ -592,16 +601,17 @@ class MetaFormer(nn.Module):
|
|
|
|
|
cur = 0
|
|
|
|
|
for i in range(num_stage):
|
|
|
|
|
stage = nn.Sequential(
|
|
|
|
|
*[MetaFormerBlock(dim=dims[i],
|
|
|
|
|
token_mixer=token_mixers[i],
|
|
|
|
|
mlp=mlps[i],
|
|
|
|
|
norm_layer=norm_layers[i],
|
|
|
|
|
drop_path=dp_rates[cur + j],
|
|
|
|
|
layer_scale_init_value=layer_scale_init_values[i],
|
|
|
|
|
res_scale_init_value=res_scale_init_values[i],
|
|
|
|
|
downsample_layers[i],
|
|
|
|
|
*[MetaFormerBlock(
|
|
|
|
|
dim=dims[i],
|
|
|
|
|
token_mixer=token_mixers[i],
|
|
|
|
|
mlp=mlps[i],
|
|
|
|
|
norm_layer=norm_layers[i],
|
|
|
|
|
drop_path=dp_rates[cur + j],
|
|
|
|
|
layer_scale_init_value=layer_scale_init_values[i],
|
|
|
|
|
res_scale_init_value=res_scale_init_values[i],
|
|
|
|
|
) for j in range(depths[i])]
|
|
|
|
|
)
|
|
|
|
|
stages.append(downsample_layers[i])
|
|
|
|
|
stages.append(stage)
|
|
|
|
|
cur += depths[i]
|
|
|
|
|
|
|
|
|
@ -639,8 +649,20 @@ class MetaFormer(nn.Module):
|
|
|
|
|
x = self.head(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_metaformer(variant, pretrained=False, **kwargs):
|
|
|
|
|
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2))))
|
|
|
|
|
out_indices = kwargs.pop('out_indices', default_out_indices)
|
|
|
|
|
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
|
MetaFormer,
|
|
|
|
|
variant,
|
|
|
|
|
pretrained,
|
|
|
|
|
pretrained_filter_fn=checkpoint_filter_fn,
|
|
|
|
|
feature_cfg=dict(flatten_sequential=True, out_indices = out_indices),
|
|
|
|
|
**kwargs)
|
|
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
'''
|
|
|
|
|
@register_model
|
|
|
|
|
def identityformer_s12(pretrained=False, **kwargs):
|
|
|
|
|
model = MetaFormer(
|
|
|
|
@ -656,6 +678,17 @@ def identityformer_s12(pretrained=False, **kwargs):
|
|
|
|
|
model.load_state_dict(state_dict)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def identityformer_s12(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
depths=[2, 2, 6, 2],
|
|
|
|
|
dims=[64, 128, 320, 512],
|
|
|
|
|
token_mixers=nn.Identity,
|
|
|
|
|
norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
|
|
|
|
|
**kwargs)
|
|
|
|
|
return _create_metaformer('identityformer_s12', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def identityformer_s24(pretrained=False, **kwargs):
|
|
|
|
|