From 55d4eb78a2a3fa6f9b00ad98c76437d6a0b3fb95 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 07:47:52 -0800 Subject: [PATCH] merge with poolformer, initial version --- timm/models/metaformers.py | 1070 ++++++++++-------------------------- 1 file changed, 286 insertions(+), 784 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 0020e789..d96d6de9 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -26,7 +26,7 @@ from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d +from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1 from timm.layers.helpers import to_2tuple from ._builder import build_model_with_cfg from ._features import FeatureInfo @@ -45,163 +45,177 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 1.0, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'classifier': 'head', 'first_conv': 'stages.0.downsample.conv', + 'classifier': 'head', 'first_conv': 'patch_embed.conv', **kwargs } -default_cfgs = { - 'identityformer_s12': _cfg( +cfgs_v2 = generate_default_cfgs({ + 'poolformerv1_s12.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar', + crop_pct=0.9), + 'poolformerv1_s24.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar', + crop_pct=0.9), + 'poolformerv1_s36.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s36.pth.tar', + crop_pct=0.9), + 'poolformerv1_m36.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m36.pth.tar', + crop_pct=0.95), + 'poolformerv1_m48.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar', + crop_pct=0.95), + + 'identityformer_s12.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'), - 'identityformer_s24': _cfg( + 'identityformer_s24.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'), - 'identityformer_s36': _cfg( + 'identityformer_s36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'), - 'identityformer_m36': _cfg( + 'identityformer_m36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'), - 'identityformer_m48': _cfg( + 'identityformer_m48.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'), - 'randformer_s12': _cfg( + 'randformer_s12.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'), - 'randformer_s24': _cfg( + 'randformer_s24.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'), - 'randformer_s36': _cfg( + 'randformer_s36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'), - 'randformer_m36': _cfg( + 'randformer_m36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'), - 'randformer_m48': _cfg( + 'randformer_m48.sail_in1k': _cfg( 
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'), - 'poolformerv2_s12': _cfg( + 'poolformerv2_s12.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'), - 'poolformerv2_s24': _cfg( + 'poolformerv2_s24.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'), - 'poolformerv2_s36': _cfg( + 'poolformerv2_s36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'), - 'poolformerv2_m36': _cfg( + 'poolformerv2_m36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'), - 'poolformerv2_m48': _cfg( + 'poolformerv2_m48.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'), - 'convformer_s18': _cfg( + 'convformer_s18.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'), - 'convformer_s18_384': _cfg( + 'convformer_s18.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', input_size=(3, 384, 384)), - 'convformer_s18_in21ft1k': _cfg( + 'convformer_s18.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'), - 'convformer_s18_384_in21ft1k': _cfg( + 'convformer_s18.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'convformer_s18_in21k': _cfg( + 'convformer_s18.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth', num_classes=21841), - 'convformer_s36': _cfg( + 'convformer_s36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'), - 'convformer_s36_384': _cfg( + 'convformer_s36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth', input_size=(3, 384, 384)), - 'convformer_s36_in21ft1k': _cfg( + 'convformer_s36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'), - 'convformer_s36_384_in21ft1k': _cfg( + 'convformer_s36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'convformer_s36_in21k': _cfg( + 'convformer_s36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth', num_classes=21841), - 'convformer_m36': _cfg( + 'convformer_m36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'), - 'convformer_m36_384': _cfg( + 'convformer_m36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth', input_size=(3, 384, 384)), - 'convformer_m36_in21ft1k': _cfg( + 'convformer_m36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'), - 'convformer_m36_384_in21ft1k': _cfg( + 'convformer_m36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'convformer_m36_in21k': _cfg( + 'convformer_m36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth', num_classes=21841), - 'convformer_b36': _cfg( + 'convformer_b36.sail_in1k': _cfg( 
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'), - 'convformer_b36_384': _cfg( + 'convformer_b36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth', input_size=(3, 384, 384)), - 'convformer_b36_in21ft1k': _cfg( + 'convformer_b36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'), - 'convformer_b36_384_in21ft1k': _cfg( + 'convformer_b36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'convformer_b36_in21k': _cfg( + 'convformer_b36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth', num_classes=21841), - 'caformer_s18': _cfg( + 'caformer_s18.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'), - 'caformer_s18_384': _cfg( + 'caformer_s18.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth', input_size=(3, 384, 384)), - 'caformer_s18_in21ft1k': _cfg( + 'caformer_s18.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'), - 'caformer_s18_384_in21ft1k': _cfg( + 'caformer_s18.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'caformer_s18_in21k': _cfg( + 'caformer_s18.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth', num_classes=21841), - 'caformer_s36': _cfg( + 'caformer_s36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'), - 'caformer_s36_384': _cfg( + 'caformer_s36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth', input_size=(3, 384, 384)), - 'caformer_s36_in21ft1k': _cfg( + 'caformer_s36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'), - 'caformer_s36_384_in21ft1k': _cfg( + 'caformer_s36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'caformer_s36_in21k': _cfg( + 'caformer_s36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth', num_classes=21841), - 'caformer_m36': _cfg( + 'caformer_m36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'), - 'caformer_m36_384': _cfg( + 'caformer_m36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth', input_size=(3, 384, 384)), - 'caformer_m36_in21ft1k': _cfg( + 'caformer_m36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'), - 'caformer_m36_384_in21ft1k': _cfg( + 'caformer_m36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'caformer_m36_in21k': _cfg( + 'caformer_m36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth', num_classes=21841), - 'caformer_b36': _cfg( + 'caformer_b36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'), - 'caformer_b36_384': _cfg( + 'caformer_b36.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth', input_size=(3, 384, 384)), - 'caformer_b36_in21ft1k': _cfg( + 'caformer_b36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'), - 'caformer_b36_384_in21ft1k': _cfg( + 'caformer_b36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'caformer_b36_in21k': _cfg( + 'caformer_b36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', num_classes=21841), -} - -cfgs_v2 = generate_default_cfgs(default_cfgs) +}) class Downsampling(nn.Module): """ @@ -290,6 +304,11 @@ class StarReLU(nn.Module): def forward(self, x): return self.scale * self.relu(x)**2 + self.bias +class Conv2dChannelsLast(nn.Conv2d): + """ nn.Conv2d that accepts channels-last (B, H, W, C) input and returns output in the same layout. """ + def forward(self, x): + x = x.permute(0, 3, 1, 2) + return self._conv_forward(x, self.weight, self.bias).permute(0, 2, 3, 1) + class Attention(nn.Module): """ @@ -492,19 +511,28 @@ class Pooling(nn.Module): class Mlp(nn.Module): """ MLP as used in MetaFormer models, e.g. Transformer, MLP-Mixer, PoolFormer, MetaFormer baselines and related networks. - Mostly copied from timm. + Modified from the standard timm implementation to accept a configurable linear layer type (mlp_fn). """ - def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False): + def __init__( + self, + dim, + mlp_ratio=4, + out_features=None, + act_layer=StarReLU, + mlp_fn=nn.Linear, + drop=0., + bias=False + ): super().__init__() in_features = dim out_features = out_features or in_features hidden_features = int(mlp_ratio * in_features) drop_probs = to_2tuple(drop) - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.fc1 = mlp_fn(in_features, hidden_features, bias=bias) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.fc2 = mlp_fn(hidden_features, out_features, bias=bias) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): @@ -516,6 +544,7 @@ class Mlp(nn.Module): return x + class MlpHead(nn.Module): """ MLP classification head """ @@ -539,18 +568,24 @@ class MlpHead(nn.Module): return x + class MetaFormerBlock(nn.Module): """ Implementation of one MetaFormer block. """ - def __init__(self, dim, - token_mixer=nn.Identity, - mlp=Mlp, - norm_layer=nn.LayerNorm, - drop=0., drop_path=0., - layer_scale_init_value=None, - res_scale_init_value=None - ): + def __init__( + self, + dim, + token_mixer=nn.Identity, + mlp=Mlp, + mlp_fn=nn.Linear, + mlp_act=StarReLU, + mlp_bias=False, + norm_layer=nn.LayerNorm, + drop=0., drop_path=0., + layer_scale_init_value=None, + res_scale_init_value=None + ): super().__init__() @@ -563,7 +598,13 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() self.norm2 = norm_layer(dim) - self.mlp = mlp(dim=dim, drop=drop) + self.mlp = mlp( + dim=dim, + drop=drop, + mlp_fn=mlp_fn, + act_layer=mlp_act, + bias=mlp_bias ) self.drop_path2 = DropPath(drop_path) if drop_path > 0.
else nn.Identity() self.layer_scale2 = Scale(dim=dim, init_value=layer_scale_init_value) \ if layer_scale_init_value else nn.Identity() @@ -573,6 +614,7 @@ class MetaFormerBlock(nn.Module): def forward(self, x): B, C, H, W = x.shape - x = x.view(B, H, W, C) + x = x.permute(0, 2, 3, 1) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( @@ -588,23 +630,6 @@ class MetaFormerBlock(nn.Module): - x = x.view(B, C, H, W) + x = x.permute(0, 3, 1, 2) return x - -r""" -downsampling (stem) for the first stage is a layer of conv with k7, s4 and p2 -downsamplings for the last 3 stages is a layer of conv with k3, s2 and p1 -DOWNSAMPLE_LAYERS_FOUR_STAGES format: [Downsampling, Downsampling, Downsampling, Downsampling] -use `partial` to specify some arguments -""" -DOWNSAMPLE_LAYERS_FOUR_STAGES = [partial(Downsampling, - kernel_size=7, stride=4, padding=2, - post_norm=partial(LayerNormGeneral, bias=False, eps=1e-6) - )] + \ - [partial(Downsampling, - kernel_size=3, stride=2, padding=1, - pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), pre_permute=False - )]*3 - - class MetaFormer(nn.Module): r""" MetaFormer A PyTorch impl of : `MetaFormer Baselines for Vision` - @@ -628,22 +653,29 @@ class MetaFormer(nn.Module): output_norm: norm before classifier head. Default: partial(nn.LayerNorm, eps=1e-6). head_fn: classification head. Default: nn.Linear. """ - def __init__(self, in_chans=3, num_classes=1000, - depths=[2, 2, 6, 2], - dims=[64, 128, 320, 512], - downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES, - token_mixers=nn.Identity, - mlps=Mlp, - norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False), - drop_path_rate=0., - head_dropout=0.0, - layer_scale_init_values=None, - res_scale_init_values=[None, None, 1.0, 1.0], - output_norm=partial(nn.LayerNorm, eps=1e-6), - head_fn=nn.Linear, - global_pool = 'avg', - **kwargs, - ): + def __init__( + self, + in_chans=3, + num_classes=1000, + depths=[2, 2, 6, 2], + dims=[64, 128, 320, 512], + downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), + token_mixers=nn.Identity, + mlps=Mlp, + mlp_fn=nn.Linear, + mlp_act=StarReLU, + mlp_bias=False, + norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False), + drop_path_rate=0., + head_dropout=0.0, + layer_scale_init_values=None, + res_scale_init_values=[None, None, 1.0, 1.0], + output_norm=partial(nn.LayerNorm, eps=1e-6), + head_fn=nn.Linear, + global_pool='avg', + **kwargs, + ): super().__init__() self.num_classes = num_classes self.head_fn = head_fn @@ -656,44 +688,66 @@ class MetaFormer(nn.Module): if not isinstance(dims, (list, tuple)): dims = [dims] - num_stage = len(depths) - self.num_stage = num_stage - + self.num_stages = len(depths) - if not isinstance(downsample_layers, (list, tuple)): - downsample_layers = [downsample_layers] * num_stage - down_dims = [in_chans] + dims - downsample_layers = nn.ModuleList( - [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)] - ) - if not isinstance(token_mixers, (list, tuple)): - token_mixers = [token_mixers] * num_stage + token_mixers = [token_mixers] * self.num_stages if not isinstance(mlps, (list, tuple)): - mlps = [mlps] * num_stage + mlps = [mlps] * self.num_stages if not isinstance(norm_layers, (list, tuple)): - norm_layers = [norm_layers] * num_stage + norm_layers = [norm_layers] * self.num_stages - dp_rates=[x.item() for x in torch.linspace(0,
drop_path_rate, sum(depths))] if not isinstance(layer_scale_init_values, (list, tuple)): - layer_scale_init_values = [layer_scale_init_values] * num_stage + layer_scale_init_values = [layer_scale_init_values] * self.num_stages if not isinstance(res_scale_init_values, (list, tuple)): - res_scale_init_values = [res_scale_init_values] * num_stage + res_scale_init_values = [res_scale_init_values] * self.num_stages self.grad_checkpointing = False self.feature_info = [] + dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + + self.patch_embed = Downsampling( + in_chans, + dims[0], + kernel_size=7, + stride=4, + padding=2, + post_norm=downsample_norm + ) + stages = nn.ModuleList() # each stage consists of multiple metaformer blocks cur = 0 - for i in range(num_stage): + for i in range(self.num_stages): stage = nn.Sequential(OrderedDict([ - ('downsample', downsample_layers[i]), + ('downsample', nn.Identity() if i == 0 else Downsampling( + dims[i-1], + dims[i], + kernel_size=3, + stride=2, + padding=1, + pre_norm=downsample_norm, + pre_permute=False + )), ('blocks', nn.Sequential(*[MetaFormerBlock( dim=dims[i], token_mixer=token_mixers[i], mlp=mlps[i], + mlp_fn=mlp_fn, + mlp_act=mlp_act, + mlp_bias=mlp_bias, norm_layer=norm_layers[i], drop_path=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_values[i], @@ -762,6 +816,7 @@ class MetaFormer(nn.Module): return x def forward_features(self, x): + x = self.patch_embed(x) x = self.stages(x) @@ -773,13 +828,21 @@ class MetaFormer(nn.Module): return x def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): - + k = k.replace('proj', 'conv') + k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) + k = k.replace('network.1', 'downsample_layers.1') + k = k.replace('network.3', 'downsample_layers.2') + k = k.replace('network.5', 'downsample_layers.3') + k = k.replace('network.2', 'network.1') + k = k.replace('network.4', 'network.2') + k = k.replace('network.6', 'network.3') + k = k.replace('network', 'stages') k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) + k = k.replace('stages.0.downsample', 'patch_embed') out_dict[k] = v return out_dict @@ -797,23 +860,86 @@ def _create_metaformer(variant, pretrained=False, **kwargs): **kwargs) return model -''' + @register_model -def identityformer_s12(pretrained=False, **kwargs): - model = MetaFormer( +def poolformerv1_s12(pretrained=False, **kwargs): + model_kwargs = dict( depths=[2, 2, 6, 2], dims=[64, 128, 320, 512], - token_mixers=nn.Identity, - norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + downsample_norm=None, + token_mixers=Pooling, + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), + mlp_act=nn.GELU, + mlp_bias=True, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), + layer_scale_init_values=1e-5, + res_scale_init_values=None, **kwargs) - model.default_cfg = default_cfgs['identityformer_s12'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('poolformerv1_s12', pretrained=pretrained, **model_kwargs) -''' +@register_model +def poolformerv1_s24(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[4, 4, 12, 4], + dims=[64, 128, 320, 512], + downsample_norm=None, + token_mixers=Pooling, +
mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), + mlp_act=nn.GELU, + mlp_bias=True, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), + layer_scale_init_values=1e-5, + res_scale_init_values=None, + **kwargs) + return _create_metaformer('poolformerv1_s24', pretrained=pretrained, **model_kwargs) + +@register_model +def poolformerv1_s36(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[6, 6, 18, 6], + dims=[64, 128, 320, 512], + downsample_norm=None, + token_mixers=Pooling, + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), + mlp_act=nn.GELU, + mlp_bias=True, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), + layer_scale_init_values=1e-6, + res_scale_init_values=None, + **kwargs) + return _create_metaformer('poolformerv1_s36', pretrained=pretrained, **model_kwargs) + +@register_model +def poolformerv1_m36(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[6, 6, 18, 6], + dims=[96, 192, 384, 768], + downsample_norm=None, + token_mixers=Pooling, + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), + mlp_act=nn.GELU, + mlp_bias=True, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), + layer_scale_init_values=1e-6, + res_scale_init_values=None, + **kwargs) + return _create_metaformer('poolformerv1_m36', pretrained=pretrained, **model_kwargs) + +@register_model +def poolformerv1_m48(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[8, 8, 24, 8], + dims=[96, 192, 384, 768], + downsample_norm=None, + token_mixers=Pooling, + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), + mlp_act=nn.GELU, + mlp_bias=True, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), + layer_scale_init_values=1e-6, + res_scale_init_values=None, + **kwargs) + return _create_metaformer('poolformerv1_m48', pretrained=pretrained, **model_kwargs) @register_model def identityformer_s12(pretrained=False, **kwargs): @@ -833,13 +959,7 @@ def identityformer_s24(pretrained=False, **kwargs): token_mixers=nn.Identity, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['identityformer_s24'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('identityformer_s24', pretrained=pretrained, **model_kwargs) @register_model def identityformer_s36(pretrained=False, **kwargs): @@ -849,13 +969,7 @@ def identityformer_s36(pretrained=False, **kwargs): token_mixers=nn.Identity, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['identityformer_s36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('identityformer_s36', pretrained=pretrained, **model_kwargs) @register_model def identityformer_m36(pretrained=False, **kwargs): @@ -865,13 +979,7 @@ def identityformer_m36(pretrained=False, **kwargs): token_mixers=nn.Identity, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['identityformer_m36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('identityformer_m36', pretrained=pretrained, **model_kwargs) @register_model def identityformer_m48(pretrained=False, **kwargs): @@ -881,13 +989,7 @@ def
identityformer_m48(pretrained=False, **kwargs): token_mixers=nn.Identity, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['identityformer_m48'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('identityformer_m48', pretrained=pretrained, **model_kwargs) @register_model def randformer_s12(pretrained=False, **kwargs): @@ -897,13 +999,7 @@ def randformer_s12(pretrained=False, **kwargs): token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['randformer_s12'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('randformer_s12', pretrained=pretrained, **model_kwargs) @register_model def randformer_s24(pretrained=False, **kwargs): @@ -913,13 +1009,7 @@ def randformer_s24(pretrained=False, **kwargs): token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['randformer_s24'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('randformer_s24', pretrained=pretrained, **model_kwargs) @register_model def randformer_s36(pretrained=False, **kwargs): @@ -929,13 +1019,7 @@ def randformer_s36(pretrained=False, **kwargs): token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['randformer_s36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('randformer_s36', pretrained=pretrained, **model_kwargs) @register_model def randformer_m36(pretrained=False, **kwargs): @@ -945,13 +1029,7 @@ def randformer_m36(pretrained=False, **kwargs): token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['randformer_m36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('randformer_m36', pretrained=pretrained, **model_kwargs) @register_model def randformer_m48(pretrained=False, **kwargs): @@ -961,14 +1039,7 @@ def randformer_m48(pretrained=False, **kwargs): token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['randformer_m48'] - if pretrained: - state_dict = 
torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - + return _create_metaformer('randformer_m48', pretrained=pretrained, **model_kwargs) @register_model def poolformerv2_s12(pretrained=False, **kwargs): @@ -978,13 +1049,7 @@ def poolformerv2_s12(pretrained=False, **kwargs): token_mixers=Pooling, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['poolformerv2_s12'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('poolformerv2_s12', pretrained=pretrained, **model_kwargs) @register_model def poolformerv2_s24(pretrained=False, **kwargs): @@ -994,12 +1059,8 @@ def poolformerv2_s24(pretrained=False, **kwargs): token_mixers=Pooling, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['poolformerv2_s24'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('poolformerv2_s24', pretrained=pretrained, **model_kwargs) + @register_model @@ -1010,12 +1071,8 @@ def poolformerv2_s36(pretrained=False, **kwargs): token_mixers=Pooling, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['poolformerv2_s36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('poolformerv2_s36', pretrained=pretrained, **model_kwargs) + @register_model @@ -1026,12 +1083,7 @@ def poolformerv2_m36(pretrained=False, **kwargs): token_mixers=Pooling, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['poolformerv2_m36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('poolformerv2_m36', pretrained=pretrained, **model_kwargs) @register_model @@ -1042,92 +1094,21 @@ def poolformerv2_m48(pretrained=False, **kwargs): token_mixers=Pooling, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['poolformerv2_m48'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('poolformerv2_m48', pretrained=pretrained, **model_kwargs) -@register_model -def convformer_s18(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s18'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - @register_model -def convformer_s18_384(pretrained=False, **kwargs): - 
model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s18_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_s18_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s18_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_s18_384_in21ft1k(pretrained=False, **kwargs): +def convformer_s18(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 3, 9, 3], dims=[64, 128, 320, 512], token_mixers=SepConv, head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['convformer_s18_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('convformer_s18', pretrained=pretrained, **model_kwargs) -@register_model -def convformer_s18_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s18_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model @register_model @@ -1138,76 +1119,7 @@ def convformer_s36(pretrained=False, **kwargs): token_mixers=SepConv, head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['convformer_s36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_s36_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s36_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_s36_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s36_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_s36_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s36_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu",
check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_s36_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s36_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('convformer_s36', pretrained=pretrained, **model_kwargs) @register_model @@ -1218,76 +1130,8 @@ def convformer_m36(pretrained=False, **kwargs): token_mixers=SepConv, head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['convformer_m36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('convformer_m36', pretrained=pretrained, **model_kwargs) -@register_model -def convformer_m36_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_m36_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_m36_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_m36_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_m36_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_m36_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_m36_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_m36_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model @register_model @@ -1298,76 +1142,9 @@ def convformer_b36(pretrained=False, **kwargs): token_mixers=SepConv, head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['convformer_b36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_b36_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_b36_384'] - if 
pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_b36_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_b36_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('convformer_b36', pretrained=pretrained, **model_kwargs) -@register_model -def convformer_b36_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_b36_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_b36_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_b36_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model @register_model @@ -1378,76 +1155,8 @@ def caformer_s18(pretrained=False, **kwargs): token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['caformer_s18'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s18_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s18_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s18_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s18_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('caformer_s18', pretrained=pretrained, **model_kwargs) -@register_model -def caformer_s18_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s18_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - 
model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s18_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s18_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model @register_model @@ -1458,76 +1167,7 @@ def caformer_s36(pretrained=False, **kwargs): token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['caformer_s36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s36_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s36_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s36_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s36_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s36_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s36_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s36_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s36_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('caformer_s36', pretrained=pretrained, **model_kwargs) @register_model @@ -1538,76 +1178,7 @@ def caformer_m36(pretrained=False, **kwargs): token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['caformer_m36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_m36_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=[SepConv, 
SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_m36_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_m36_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_m36_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_m36_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_m36_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_m364_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_m364_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('caformer_m36', pretrained=pretrained, **model_kwargs) @register_model @@ -1618,73 +1189,4 @@ def caformer_b36(pretrained=False, **kwargs): token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['caformer_b36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_b36_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_b36_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_b36_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_b36_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_b36_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_b36_384_in21ft1k'] - if 
pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_b36_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_b36_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model \ No newline at end of file + return _create_metaformer('caformer_b36', pretrained=pretrained, **model_kwargs)
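
Note for reviewers: a standalone sketch of what the reworked checkpoint_filter_fn does to the keys of the original PoolFormer v1 releases. The remap logic below is copied out of this patch purely for illustration and is not itself part of the diff:

import re

def remap_poolformer_key(k):
    # same substitutions as checkpoint_filter_fn in this patch
    k = k.replace('proj', 'conv')
    k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
    k = k.replace('network.1', 'downsample_layers.1')
    k = k.replace('network.3', 'downsample_layers.2')
    k = k.replace('network.5', 'downsample_layers.3')
    k = k.replace('network.2', 'network.1')
    k = k.replace('network.4', 'network.2')
    k = k.replace('network.6', 'network.3')
    k = k.replace('network', 'stages')
    k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
    k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
    k = k.replace('stages.0.downsample', 'patch_embed')
    return k

# original v1 keys -> keys expected by this MetaFormer implementation
print(remap_poolformer_key('patch_embed.proj.weight'))     # patch_embed.conv.weight
print(remap_poolformer_key('network.0.0.mlp.fc1.weight'))  # stages.0.blocks.0.mlp.fc1.weight
print(remap_poolformer_key('network.1.proj.weight'))       # stages.1.downsample.conv.weight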
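The new Conv2dChannelsLast exists because MetaFormerBlock now keeps activations in (B, H, W, C) layout while PoolFormer v1's MLP was built from 1x1 convolutions. A quick shape check (standalone, torch only):

import torch
import torch.nn as nn

class Conv2dChannelsLast(nn.Conv2d):
    # as in the patch: permute NHWC -> NCHW, run the conv, permute back
    def forward(self, x):
        x = x.permute(0, 3, 1, 2)
        return self._conv_forward(x, self.weight, self.bias).permute(0, 2, 3, 1)

fc1 = Conv2dChannelsLast(64, 256, kernel_size=1)  # drop-in for nn.Linear(64, 256) on NHWC input
x = torch.randn(2, 56, 56, 64)                    # (B, H, W, C), as inside a MetaFormerBlock
print(fc1(x).shape)                               # torch.Size([2, 56, 56, 256])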
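Finally, a hypothetical smoke test for the new registry entries. This assumes the patch applies cleanly on a timm checkout and that the remaining WIP issues in this file are resolved; the model name comes from this patch, and pretrained weights download only if the URLs above resolve:

import timm
import torch

model = timm.create_model('poolformerv1_s12', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])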