From 530709d299bc20e32f154f48b9369ebba26221e9 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 20:22:23 -0800 Subject: [PATCH 001/102] update --- timm/models/__init__.py | 1 + timm/models/metaformers.py | 1519 ++++++++++++++++++++++++++++++++++++ 2 files changed, 1520 insertions(+) create mode 100644 timm/models/metaformers.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 5ecc8915..1d00995a 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -26,6 +26,7 @@ from .inception_v3 import * from .inception_v4 import * from .levit import * from .maxxvit import * +from .metaformers import * from .mlp_mixer import * from .mobilenetv3 import * from .mobilevit import * diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py new file mode 100644 index 00000000..05adde19 --- /dev/null +++ b/timm/models/metaformers.py @@ -0,0 +1,1519 @@ +# Copyright 2022 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2, +ConvFormer and CAFormer. +Some implementations are modified from timm (https://github.com/rwightman/pytorch-image-models). +""" +from functools import partial +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_, DropPath +from timm.models.registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers.helpers import to_2tuple + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': 1.0, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'identityformer_s12': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'), + 'identityformer_s24': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'), + 'identityformer_s36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'), + 'identityformer_m36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'), + 'identityformer_m48': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'), + + + 'randformer_s12': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'), + 'randformer_s24': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'), + 'randformer_s36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'), + 'randformer_m36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'), + 'randformer_m48': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'), + + 'poolformerv2_s12': _cfg( + 
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'), + 'poolformerv2_s24': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'), + 'poolformerv2_s36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'), + 'poolformerv2_m36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'), + 'poolformerv2_m48': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'), + + + + 'convformer_s18': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'), + 'convformer_s18_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', + input_size=(3, 384, 384)), + 'convformer_s18_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'), + 'convformer_s18_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_s18_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth', + num_classes=21841), + + 'convformer_s36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'), + 'convformer_s36_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth', + input_size=(3, 384, 384)), + 'convformer_s36_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'), + 'convformer_s36_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_s36_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth', + num_classes=21841), + + 'convformer_m36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'), + 'convformer_m36_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth', + input_size=(3, 384, 384)), + 'convformer_m36_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'), + 'convformer_m36_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_m36_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth', + num_classes=21841), + + 'convformer_b36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'), + 'convformer_b36_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth', + input_size=(3, 384, 384)), + 'convformer_b36_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'), + 'convformer_b36_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_b36_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth', + num_classes=21841), + + + 'caformer_s18': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'), + 'caformer_s18_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth', + 
input_size=(3, 384, 384)), + 'caformer_s18_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'), + 'caformer_s18_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_s18_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth', + num_classes=21841), + + 'caformer_s36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'), + 'caformer_s36_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth', + input_size=(3, 384, 384)), + 'caformer_s36_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'), + 'caformer_s36_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_s36_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth', + num_classes=21841), + + 'caformer_m36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'), + 'caformer_m36_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth', + input_size=(3, 384, 384)), + 'caformer_m36_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'), + 'caformer_m36_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_m36_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth', + num_classes=21841), + + 'caformer_b36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'), + 'caformer_b36_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth', + input_size=(3, 384, 384)), + 'caformer_b36_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'), + 'caformer_b36_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_b36_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', + num_classes=21841), +} + + +class Downsampling(nn.Module): + """ + Downsampling implemented by a layer of convolution. + """ + def __init__(self, in_channels, out_channels, + kernel_size, stride=1, padding=0, + pre_norm=None, post_norm=None, pre_permute=False): + super().__init__() + self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity() + self.pre_permute = pre_permute + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, + stride=stride, padding=padding) + self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() + + def forward(self, x): + x = self.pre_norm(x) + if self.pre_permute: + # if take [B, H, W, C] as input, permute it to [B, C, H, W] + x = x.permute(0, 3, 1, 2) + x = self.conv(x) + x = x.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C] + x = self.post_norm(x) + return x + + +class Scale(nn.Module): + """ + Scale vector by element multiplications. 
+ """ + def __init__(self, dim, init_value=1.0, trainable=True): + super().__init__() + self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable) + + def forward(self, x): + return x * self.scale + + +class SquaredReLU(nn.Module): + """ + Squared ReLU: https://arxiv.org/abs/2109.08668 + """ + def __init__(self, inplace=False): + super().__init__() + self.relu = nn.ReLU(inplace=inplace) + def forward(self, x): + return torch.square(self.relu(x)) + + +class StarReLU(nn.Module): + """ + StarReLU: s * relu(x) ** 2 + b + """ + def __init__(self, scale_value=1.0, bias_value=0.0, + scale_learnable=True, bias_learnable=True, + mode=None, inplace=False): + super().__init__() + self.inplace = inplace + self.relu = nn.ReLU(inplace=inplace) + self.scale = nn.Parameter(scale_value * torch.ones(1), + requires_grad=scale_learnable) + self.bias = nn.Parameter(bias_value * torch.ones(1), + requires_grad=bias_learnable) + def forward(self, x): + return self.scale * self.relu(x)**2 + self.bias + + +class Attention(nn.Module): + """ + Vanilla self-attention from Transformer: https://arxiv.org/abs/1706.03762. + Modified from timm. + """ + def __init__(self, dim, head_dim=32, num_heads=None, qkv_bias=False, + attn_drop=0., proj_drop=0., proj_bias=False, **kwargs): + super().__init__() + + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + + self.num_heads = num_heads if num_heads else dim // head_dim + if self.num_heads == 0: + self.num_heads = 1 + + self.attention_dim = self.num_heads * self.head_dim + + self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + + def forward(self, x): + B, H, W, C = x.shape + N = H * W + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.attention_dim) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class RandomMixing(nn.Module): + def __init__(self, num_tokens=196, **kwargs): + super().__init__() + self.random_matrix = nn.parameter.Parameter( + data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), + requires_grad=False) + def forward(self, x): + B, H, W, C = x.shape + x = x.reshape(B, H*W, C) + x = torch.einsum('mn, bnc -> bmc', self.random_matrix, x) + x = x.reshape(B, H, W, C) + return x + + +class LayerNormGeneral(nn.Module): + r""" General LayerNorm for different situations. + + Args: + affine_shape (int, list or tuple): The shape of affine weight and bias. + Usually the affine_shape=C, but in some implementation, like torch.nn.LayerNorm, + the affine_shape is the same as normalized_dim by default. + To adapt to different situations, we offer this argument here. + normalized_dim (tuple or list): Which dims to compute mean and variance. + scale (bool): Flag indicates whether to use scale or not. + bias (bool): Flag indicates whether to use scale or not. + + We give several examples to show how to specify the arguments. 
+ + LayerNorm (https://arxiv.org/abs/1607.06450): + For input shape of (B, *, C) like (B, N, C) or (B, H, W, C), + affine_shape=C, normalized_dim=(-1, ), scale=True, bias=True; + For input shape of (B, C, H, W), + affine_shape=(C, 1, 1), normalized_dim=(1, ), scale=True, bias=True. + + Modified LayerNorm (https://arxiv.org/abs/2111.11418) + that is idental to partial(torch.nn.GroupNorm, num_groups=1): + For input shape of (B, N, C), + affine_shape=C, normalized_dim=(1, 2), scale=True, bias=True; + For input shape of (B, H, W, C), + affine_shape=C, normalized_dim=(1, 2, 3), scale=True, bias=True; + For input shape of (B, C, H, W), + affine_shape=(C, 1, 1), normalized_dim=(1, 2, 3), scale=True, bias=True. + + For the several metaformer baslines, + IdentityFormer, RandFormer and PoolFormerV2 utilize Modified LayerNorm without bias (bias=False); + ConvFormer and CAFormer utilizes LayerNorm without bias (bias=False). + """ + def __init__(self, affine_shape=None, normalized_dim=(-1, ), scale=True, + bias=True, eps=1e-5): + super().__init__() + self.normalized_dim = normalized_dim + self.use_scale = scale + self.use_bias = bias + self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None + self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None + self.eps = eps + + def forward(self, x): + c = x - x.mean(self.normalized_dim, keepdim=True) + s = c.pow(2).mean(self.normalized_dim, keepdim=True) + x = c / torch.sqrt(s + self.eps) + if self.use_scale: + x = x * self.weight + if self.use_bias: + x = x + self.bias + return x + + +class SepConv(nn.Module): + r""" + Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381. + """ + def __init__(self, dim, expansion_ratio=2, + act1_layer=StarReLU, act2_layer=nn.Identity, + bias=False, kernel_size=7, padding=3, + **kwargs, ): + super().__init__() + med_channels = int(expansion_ratio * dim) + self.pwconv1 = nn.Linear(dim, med_channels, bias=bias) + self.act1 = act1_layer() + self.dwconv = nn.Conv2d( + med_channels, med_channels, kernel_size=kernel_size, + padding=padding, groups=med_channels, bias=bias) # depthwise conv + self.act2 = act2_layer() + self.pwconv2 = nn.Linear(med_channels, dim, bias=bias) + + def forward(self, x): + x = self.pwconv1(x) + x = self.act1(x) + x = x.permute(0, 3, 1, 2) + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) + x = self.act2(x) + x = self.pwconv2(x) + return x + + +class Pooling(nn.Module): + """ + Implementation of pooling for PoolFormer: https://arxiv.org/abs/2111.11418 + Modfiled for [B, H, W, C] input + """ + def __init__(self, pool_size=3, **kwargs): + super().__init__() + self.pool = nn.AvgPool2d( + pool_size, stride=1, padding=pool_size//2, count_include_pad=False) + + def forward(self, x): + y = x.permute(0, 3, 1, 2) + y = self.pool(y) + y = y.permute(0, 2, 3, 1) + return y - x + + +class Mlp(nn.Module): + """ MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks. + Mostly copied from timm. 
+ """ + def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False, **kwargs): + super().__init__() + in_features = dim + out_features = out_features or in_features + hidden_features = int(mlp_ratio * in_features) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class MlpHead(nn.Module): + """ MLP classification head + """ + def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=SquaredReLU, + norm_layer=nn.LayerNorm, head_dropout=0., bias=True): + super().__init__() + hidden_features = int(mlp_ratio * dim) + self.fc1 = nn.Linear(dim, hidden_features, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden_features) + self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias) + self.head_dropout = nn.Dropout(head_dropout) + + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.head_dropout(x) + x = self.fc2(x) + return x + + +class MetaFormerBlock(nn.Module): + """ + Implementation of one MetaFormer block. + """ + def __init__(self, dim, + token_mixer=nn.Identity, mlp=Mlp, + norm_layer=nn.LayerNorm, + drop=0., drop_path=0., + layer_scale_init_value=None, res_scale_init_value=None + ): + + super().__init__() + + self.norm1 = norm_layer(dim) + self.token_mixer = token_mixer(dim=dim, drop=drop) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.layer_scale1 = Scale(dim=dim, init_value=layer_scale_init_value) \ + if layer_scale_init_value else nn.Identity() + self.res_scale1 = Scale(dim=dim, init_value=res_scale_init_value) \ + if res_scale_init_value else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp(dim=dim, drop=drop) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.layer_scale2 = Scale(dim=dim, init_value=layer_scale_init_value) \ + if layer_scale_init_value else nn.Identity() + self.res_scale2 = Scale(dim=dim, init_value=res_scale_init_value) \ + if res_scale_init_value else nn.Identity() + + def forward(self, x): + x = self.res_scale1(x) + \ + self.layer_scale1( + self.drop_path1( + self.token_mixer(self.norm1(x)) + ) + ) + x = self.res_scale2(x) + \ + self.layer_scale2( + self.drop_path2( + self.mlp(self.norm2(x)) + ) + ) + return x + + +r""" +downsampling (stem) for the first stage is a layer of conv with k7, s4 and p2 +downsamplings for the last 3 stages is a layer of conv with k3, s2 and p1 +DOWNSAMPLE_LAYERS_FOUR_STAGES format: [Downsampling, Downsampling, Downsampling, Downsampling] +use `partial` to specify some arguments +""" +DOWNSAMPLE_LAYERS_FOUR_STAGES = [partial(Downsampling, + kernel_size=7, stride=4, padding=2, + post_norm=partial(LayerNormGeneral, bias=False, eps=1e-6) + )] + \ + [partial(Downsampling, + kernel_size=3, stride=2, padding=1, + pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), pre_permute=True + )]*3 + + +class MetaFormer(nn.Module): + r""" MetaFormer + A PyTorch impl of : `MetaFormer Baselines for Vision` - + https://arxiv.org/abs/2210.13452 + + Args: + in_chans (int): Number of input image channels. Default: 3. + num_classes (int): Number of classes for classification head. Default: 1000. 
+ depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2]. + dims (int): Feature dimension at each stage. Default: [64, 128, 320, 512]. + downsample_layers: (list or tuple): Downsampling layers before each stage. + token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity. + mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp. + norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormGeneral, eps=1e-6, bias=False). + drop_path_rate (float): Stochastic depth rate. Default: 0. + head_dropout (float): dropout for MLP classifier. Default: 0. + layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: None. + None means not use the layer scale. Form: https://arxiv.org/abs/2103.17239. + res_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: [None, None, 1.0, 1.0]. + None means not use the layer scale. From: https://arxiv.org/abs/2110.09456. + output_norm: norm before classifier head. Default: partial(nn.LayerNorm, eps=1e-6). + head_fn: classification head. Default: nn.Linear. + """ + def __init__(self, in_chans=3, num_classes=1000, + depths=[2, 2, 6, 2], + dims=[64, 128, 320, 512], + downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES, + token_mixers=nn.Identity, + mlps=Mlp, + norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False), + drop_path_rate=0., + head_dropout=0.0, + layer_scale_init_values=None, + res_scale_init_values=[None, None, 1.0, 1.0], + output_norm=partial(nn.LayerNorm, eps=1e-6), + head_fn=nn.Linear, + **kwargs, + ): + super().__init__() + self.num_classes = num_classes + + if not isinstance(depths, (list, tuple)): + depths = [depths] # it means the model has only one stage + if not isinstance(dims, (list, tuple)): + dims = [dims] + + num_stage = len(depths) + self.num_stage = num_stage + + if not isinstance(downsample_layers, (list, tuple)): + downsample_layers = [downsample_layers] * num_stage + down_dims = [in_chans] + dims + self.downsample_layers = nn.ModuleList( + [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)] + ) + + if not isinstance(token_mixers, (list, tuple)): + token_mixers = [token_mixers] * num_stage + + if not isinstance(mlps, (list, tuple)): + mlps = [mlps] * num_stage + + if not isinstance(norm_layers, (list, tuple)): + norm_layers = [norm_layers] * num_stage + + dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + if not isinstance(layer_scale_init_values, (list, tuple)): + layer_scale_init_values = [layer_scale_init_values] * num_stage + if not isinstance(res_scale_init_values, (list, tuple)): + res_scale_init_values = [res_scale_init_values] * num_stage + + self.stages = nn.ModuleList() # each stage consists of multiple metaformer blocks + cur = 0 + for i in range(num_stage): + stage = nn.Sequential( + *[MetaFormerBlock(dim=dims[i], + token_mixer=token_mixers[i], + mlp=mlps[i], + norm_layer=norm_layers[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_values[i], + res_scale_init_value=res_scale_init_values[i], + ) for j in range(depths[i])] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = output_norm(dims[-1]) + + if head_dropout > 0.0: + self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout) + else: + self.head = head_fn(dims[-1], num_classes) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + 
trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'norm'} + + def forward_features(self, x): + for i in range(self.num_stage): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + return self.norm(x.mean([1, 2])) # (B, H, W, C) -> (B, C) + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + + +@register_model +def identityformer_s12(pretrained=False, **kwargs): + model = MetaFormer( + depths=[2, 2, 6, 2], + dims=[64, 128, 320, 512], + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['identityformer_s12'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def identityformer_s24(pretrained=False, **kwargs): + model = MetaFormer( + depths=[4, 4, 12, 4], + dims=[64, 128, 320, 512], + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['identityformer_s24'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def identityformer_s36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[6, 6, 18, 6], + dims=[64, 128, 320, 512], + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['identityformer_s36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def identityformer_m36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[6, 6, 18, 6], + dims=[96, 192, 384, 768], + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['identityformer_m36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def identityformer_m48(pretrained=False, **kwargs): + model = MetaFormer( + depths=[8, 8, 24, 8], + dims=[96, 192, 384, 768], + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['identityformer_m48'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def randformer_s12(pretrained=False, **kwargs): + model = MetaFormer( + depths=[2, 2, 6, 2], + dims=[64, 128, 320, 512], + token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['randformer_s12'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], 
map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def randformer_s24(pretrained=False, **kwargs): + model = MetaFormer( + depths=[4, 4, 12, 4], + dims=[64, 128, 320, 512], + token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['randformer_s24'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def randformer_s36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[6, 6, 18, 6], + dims=[64, 128, 320, 512], + token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['randformer_s36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def randformer_m36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[6, 6, 18, 6], + dims=[96, 192, 384, 768], + token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['randformer_m36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def randformer_m48(pretrained=False, **kwargs): + model = MetaFormer( + depths=[8, 8, 24, 8], + dims=[96, 192, 384, 768], + token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['randformer_m48'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + + +@register_model +def poolformerv2_s12(pretrained=False, **kwargs): + model = MetaFormer( + depths=[2, 2, 6, 2], + dims=[64, 128, 320, 512], + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['poolformerv2_s12'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def poolformerv2_s24(pretrained=False, **kwargs): + model = MetaFormer( + depths=[4, 4, 12, 4], + dims=[64, 128, 320, 512], + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['poolformerv2_s24'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def poolformerv2_s36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[6, 
6, 18, 6], + dims=[64, 128, 320, 512], + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['poolformerv2_s36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def poolformerv2_m36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[6, 6, 18, 6], + dims=[96, 192, 384, 768], + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['poolformerv2_m36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def poolformerv2_m48(pretrained=False, **kwargs): + model = MetaFormer( + depths=[8, 8, 24, 8], + dims=[96, 192, 384, 768], + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['poolformerv2_m48'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s18(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s18'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s18_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s18_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s18_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s18_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s18_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s18_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s18_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s18_in21k'] + if pretrained: + state_dict = 
torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s36_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s36_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s36_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s36_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s36_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s36_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s36_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s36_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_m36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_m36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_m36_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_m36_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_m36_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = 
default_cfgs['convformer_m36_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_m36_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_m36_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_m36_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_m36_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_b36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_b36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_b36_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_b36_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_b36_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_b36_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_b36_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_b36_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_b36_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_b36_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s18(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 
128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s18'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s18_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s18_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s18_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s18_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s18_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s18_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s18_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s18_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s36_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s36_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s36_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s36_in21ft1k'] + if pretrained: + 
state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s36_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s36_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s36_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s36_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_m36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_m36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_m36_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_m36_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_m36_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_m36_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_m36_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_m36_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_m364_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_m364_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + 
model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_b36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_b36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_b36_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_b36_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_b36_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_b36_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_b36_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_b36_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_b36_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_b36_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model \ No newline at end of file From 1f3a661160036d0d208e9d1e2b457c32f11ed09f Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 21:07:14 -0800 Subject: [PATCH 002/102] Update metaformers.py --- timm/models/metaformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 05adde19..92d7a0c5 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -20,10 +20,10 @@ Some implementations are modified from timm (https://github.com/rwightman/pytorc from functools import partial import torch import torch.nn as nn -from timm.models.layers import trunc_normal_, DropPath +from timm.layers import trunc_normal_, DropPath from timm.models.registry import register_model from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.layers.helpers import to_2tuple +from timm.layers.helpers import to_2tuple def _cfg(url='', **kwargs): From d17bcb10a56dc95ed7fbd5cb6bdd49a2ed4bd3b6 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 21:14:21 
-0800 Subject: [PATCH 003/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 92d7a0c5..a24e3377 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -21,7 +21,7 @@ from functools import partial import torch import torch.nn as nn from timm.layers import trunc_normal_, DropPath -from timm.models.registry import register_model +from ._registry import register_model from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers.helpers import to_2tuple From 01f671ed080c2cb6e17089fc593436d05962fbbb Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 21:19:00 -0800 Subject: [PATCH 004/102] Update metaformers.py --- timm/models/metaformers.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index a24e3377..0a601783 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -347,18 +347,16 @@ class LayerNormGeneral(nn.Module): self.normalized_dim = normalized_dim self.use_scale = scale self.use_bias = bias - self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None - self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None + self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else torch.ones(affine_shape) + self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else torch.zeros(affine_shape) self.eps = eps def forward(self, x): c = x - x.mean(self.normalized_dim, keepdim=True) s = c.pow(2).mean(self.normalized_dim, keepdim=True) x = c / torch.sqrt(s + self.eps) - if self.use_scale: - x = x * self.weight - if self.use_bias: - x = x + self.bias + x = x * self.weight + x = x + self.bias return x From c7e1819ca5520e7180ba87902e923ffaf9a6cdcd Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 21:23:56 -0800 Subject: [PATCH 005/102] Update metaformers.py --- timm/models/metaformers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 0a601783..ab1b0f0c 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -588,7 +588,7 @@ class MetaFormer(nn.Module): if not isinstance(res_scale_init_values, (list, tuple)): res_scale_init_values = [res_scale_init_values] * num_stage - self.stages = nn.ModuleList() # each stage consists of multiple metaformer blocks + stages = nn.ModuleList() # each stage consists of multiple metaformer blocks cur = 0 for i in range(num_stage): stage = nn.Sequential( @@ -603,8 +603,11 @@ class MetaFormer(nn.Module): ) self.stages.append(stage) cur += depths[i] - + + self.stages = nn.Sequential(*stages) self.norm = output_norm(dims[-1]) + + if head_dropout > 0.0: self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout) From ee16588a3aaab3315be94998bb03b0e79706ae86 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 21:26:02 -0800 Subject: [PATCH 006/102] Update metaformers.py --- timm/models/metaformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index ab1b0f0c..257d4e74 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -568,7 +568,7 @@ class MetaFormer(nn.Module): if not isinstance(downsample_layers, (list, tuple)): downsample_layers = [downsample_layers] * num_stage down_dims = [in_chans] + dims - 
self.downsample_layers = nn.ModuleList( + downsample_layers = nn.ModuleList( [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)] ) @@ -604,7 +604,7 @@ class MetaFormer(nn.Module): self.stages.append(stage) cur += depths[i] - self.stages = nn.Sequential(*stages) + self.stages = nn.Sequential(zip(*downsample_layers, *stages)) self.norm = output_norm(dims[-1]) From 116345b4d39218a379d2928c44625fd287011a18 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 21:31:01 -0800 Subject: [PATCH 007/102] Update metaformers.py --- timm/models/metaformers.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 257d4e74..a7404cf9 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -627,10 +627,11 @@ class MetaFormer(nn.Module): return {'norm'} def forward_features(self, x): - for i in range(self.num_stage): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - return self.norm(x.mean([1, 2])) # (B, H, W, C) -> (B, C) + x = self.stages(x) + x = x.mean([1,2]) # TODO use adaptive pool instead of mean + x = self.norm(x) + # (B, H, W, C) -> (B, C) + return x def forward(self, x): x = self.forward_features(x) From 42741b7cdbe649480e76a82dc8d60b49e32969e1 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 21:38:34 -0800 Subject: [PATCH 008/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index a7404cf9..656f6294 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -601,7 +601,7 @@ class MetaFormer(nn.Module): res_scale_init_value=res_scale_init_values[i], ) for j in range(depths[i])] ) - self.stages.append(stage) + stages.append(stage) cur += depths[i] self.stages = nn.Sequential(zip(*downsample_layers, *stages)) From 0bde1c12181ee0828f2a405b6dd58b73f5edd9c2 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 21:43:56 -0800 Subject: [PATCH 009/102] Update metaformers.py --- timm/models/metaformers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 656f6294..99b9b53f 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -601,10 +601,11 @@ class MetaFormer(nn.Module): res_scale_init_value=res_scale_init_values[i], ) for j in range(depths[i])] ) + stages.append(downsample_layers[i]) stages.append(stage) cur += depths[i] - self.stages = nn.Sequential(zip(*downsample_layers, *stages)) + self.stages = nn.Sequential(*stages) self.norm = output_norm(dims[-1]) From ec202b4d163e4942e62c27837c30d3b679c7526f Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 23:17:25 -0800 Subject: [PATCH 010/102] Update metaformers.py --- timm/models/metaformers.py | 57 ++++++++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 12 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 99b9b53f..ff8bc0d5 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -20,12 +20,20 @@ Some implementations are modified from timm (https://github.com/rwightman/pytorc from functools import partial import torch import torch.nn as nn -from timm.layers import trunc_normal_, DropPath -from ._registry import register_model from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import trunc_normal_, DropPath from 
timm.layers.helpers import to_2tuple +from ._builder import build_model_with_cfg +from ._features import FeatureInfo +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model + +__all__ = ['MetaFormer'] + def _cfg(url='', **kwargs): return { 'url': url, @@ -187,6 +195,7 @@ default_cfgs = { num_classes=21841), } +cfgs_v2 = generate_default_cfgs(default_cfgs) class Downsampling(nn.Module): """ @@ -592,16 +601,17 @@ class MetaFormer(nn.Module): cur = 0 for i in range(num_stage): stage = nn.Sequential( - *[MetaFormerBlock(dim=dims[i], - token_mixer=token_mixers[i], - mlp=mlps[i], - norm_layer=norm_layers[i], - drop_path=dp_rates[cur + j], - layer_scale_init_value=layer_scale_init_values[i], - res_scale_init_value=res_scale_init_values[i], + downsample_layers[i], + *[MetaFormerBlock( + dim=dims[i], + token_mixer=token_mixers[i], + mlp=mlps[i], + norm_layer=norm_layers[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_values[i], + res_scale_init_value=res_scale_init_values[i], ) for j in range(depths[i])] ) - stages.append(downsample_layers[i]) stages.append(stage) cur += depths[i] @@ -639,8 +649,20 @@ class MetaFormer(nn.Module): x = self.head(x) return x - - +def _create_metaformer(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + + model = build_model_with_cfg( + MetaFormer, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices = out_indices), + **kwargs) + + return model +''' @register_model def identityformer_s12(pretrained=False, **kwargs): model = MetaFormer( @@ -656,6 +678,17 @@ def identityformer_s12(pretrained=False, **kwargs): model.load_state_dict(state_dict) return model +''' + +@register_model +def identityformer_s12(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[2, 2, 6, 2], + dims=[64, 128, 320, 512], + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + return _create_metaformer('identityformer_s12', pretrained=pretrained, **model_kwargs) @register_model def identityformer_s24(pretrained=False, **kwargs): From d90ed530dc34fad8f80fb33618400920d7d86dec Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 23:27:07 -0800 Subject: [PATCH 011/102] Update metaformers.py --- timm/models/metaformers.py | 43 ++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index ff8bc0d5..a90b25be 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -469,14 +469,19 @@ class MetaFormerBlock(nn.Module): Implementation of one MetaFormer block. """ def __init__(self, dim, - token_mixer=nn.Identity, mlp=Mlp, + token_mixer=nn.Identity, + mlp=Mlp, norm_layer=nn.LayerNorm, drop=0., drop_path=0., - layer_scale_init_value=None, res_scale_init_value=None + layer_scale_init_value=None, + res_scale_init_value=None, + downsample = nn.Identity() ): super().__init__() - + + self.downsample = nn.Identity() + self.norm1 = norm_layer(dim) self.token_mixer = token_mixer(dim=dim, drop=drop) self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -494,6 +499,7 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() def forward(self, x): + x = self.downsample(x) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( @@ -600,18 +606,18 @@ class MetaFormer(nn.Module): stages = nn.ModuleList() # each stage consists of multiple metaformer blocks cur = 0 for i in range(num_stage): - stage = nn.Sequential( - downsample_layers[i], - *[MetaFormerBlock( - dim=dims[i], - token_mixer=token_mixers[i], - mlp=mlps[i], - norm_layer=norm_layers[i], - drop_path=dp_rates[cur + j], - layer_scale_init_value=layer_scale_init_values[i], - res_scale_init_value=res_scale_init_values[i], + stage = nn.Sequential(*[MetaFormerBlock( + dim=dims[i], + token_mixer=token_mixers[i], + mlp=mlps[i], + norm_layer=norm_layers[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_values[i], + res_scale_init_value=res_scale_init_values[i], + downsample = downsample_layers[i] ) for j in range(depths[i])] ) + stages.append(stage) cur += depths[i] @@ -649,6 +655,17 @@ class MetaFormer(nn.Module): x = self.head(x) return x +def checkpoint_filter_fn(state_dict, model): + + import re + out_dict = {} + for k, v in state_dict.items(): + + k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) + out_dict[k] = v + return out_dict + + def _create_metaformer(variant, pretrained=False, **kwargs): default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (2, 2, 6, 2)))) out_indices = kwargs.pop('out_indices', default_out_indices) From d484d8618934294dc97e2b29c8861ab8f9da4fb3 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 23:31:38 -0800 Subject: [PATCH 012/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index a90b25be..bedc8738 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -419,7 +419,7 @@ class Mlp(nn.Module): """ MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks. Mostly copied from timm. 
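
Aside, not from the patch itself: a rough sketch of what the downsample_layers substitution in checkpoint_filter_fn above does to keys written in the original MetaFormer checkpoint layout; the example keys are hypothetical but follow that naming scheme.

    import re

    keys = ['downsample_layers.0.conv.weight', 'downsample_layers.2.post_norm.bias']
    remapped = [re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) for k in keys]
    # -> ['stages.0.downsample.conv.weight', 'stages.2.downsample.post_norm.bias']
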
""" - def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False, **kwargs): + def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False): super().__init__() in_features = dim out_features = out_features or in_features From fcaf9b8d233fa784135670c8a30425329dc9728f Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 23:34:02 -0800 Subject: [PATCH 013/102] Update metaformers.py --- timm/models/metaformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index bedc8738..8857cd41 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -356,8 +356,8 @@ class LayerNormGeneral(nn.Module): self.normalized_dim = normalized_dim self.use_scale = scale self.use_bias = bias - self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else torch.ones(affine_shape) - self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else torch.zeros(affine_shape) + self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else 1 + self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else 0 self.eps = eps def forward(self, x): From f71beadc29aa416363f8c45c34fcf3997b9805ff Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 23:45:49 -0800 Subject: [PATCH 014/102] Update metaformers.py --- timm/models/metaformers.py | 49 +++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 8857cd41..1eea2129 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -316,7 +316,7 @@ class RandomMixing(nn.Module): x = x.reshape(B, H, W, C) return x - +''' class LayerNormGeneral(nn.Module): r""" General LayerNorm for different situations. @@ -367,7 +367,54 @@ class LayerNormGeneral(nn.Module): x = x * self.weight x = x + self.bias return x +''' +class LayerNormGeneral(nn.Module): + r""" General LayerNorm for different situations. + Args: + affine_shape (int, list or tuple): The shape of affine weight and bias. + Usually the affine_shape=C, but in some implementation, like torch.nn.LayerNorm, + the affine_shape is the same as normalized_dim by default. + To adapt to different situations, we offer this argument here. + normalized_dim (tuple or list): Which dims to compute mean and variance. + scale (bool): Flag indicates whether to use scale or not. + bias (bool): Flag indicates whether to use scale or not. + We give several examples to show how to specify the arguments. + LayerNorm (https://arxiv.org/abs/1607.06450): + For input shape of (B, *, C) like (B, N, C) or (B, H, W, C), + affine_shape=C, normalized_dim=(-1, ), scale=True, bias=True; + For input shape of (B, C, H, W), + affine_shape=(C, 1, 1), normalized_dim=(1, ), scale=True, bias=True. + Modified LayerNorm (https://arxiv.org/abs/2111.11418) + that is idental to partial(torch.nn.GroupNorm, num_groups=1): + For input shape of (B, N, C), + affine_shape=C, normalized_dim=(1, 2), scale=True, bias=True; + For input shape of (B, H, W, C), + affine_shape=C, normalized_dim=(1, 2, 3), scale=True, bias=True; + For input shape of (B, C, H, W), + affine_shape=(C, 1, 1), normalized_dim=(1, 2, 3), scale=True, bias=True. + For the several metaformer baslines, + IdentityFormer, RandFormer and PoolFormerV2 utilize Modified LayerNorm without bias (bias=False); + ConvFormer and CAFormer utilizes LayerNorm without bias (bias=False). 
+ """ + def __init__(self, affine_shape=None, normalized_dim=(-1, ), scale=True, + bias=True, eps=1e-5): + super().__init__() + self.normalized_dim = normalized_dim + self.use_scale = scale + self.use_bias = bias + self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None + self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None + self.eps = eps + def forward(self, x): + c = x - x.mean(self.normalized_dim, keepdim=True) + s = c.pow(2).mean(self.normalized_dim, keepdim=True) + x = c / torch.sqrt(s + self.eps) + if self.use_scale: + x = x * self.weight + if self.use_bias: + x = x + self.bias + return class SepConv(nn.Module): r""" From 358c4ae7ecf2abb722bb68329957ac24347643a5 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 23:53:53 -0800 Subject: [PATCH 015/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 1eea2129..be03eec3 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -414,7 +414,7 @@ class LayerNormGeneral(nn.Module): x = x * self.weight if self.use_bias: x = x + self.bias - return + return x class SepConv(nn.Module): r""" From 926d886527de66accf301678ddd5aada5ccb8c9b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Jan 2023 23:55:55 -0800 Subject: [PATCH 016/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index be03eec3..e0189720 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -527,7 +527,7 @@ class MetaFormerBlock(nn.Module): super().__init__() - self.downsample = nn.Identity() + self.downsample = downsample self.norm1 = norm_layer(dim) self.token_mixer = token_mixer(dim=dim, drop=drop) From 7f149f31d4e79b0a9f88e5b80a4f0455cfaec08b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 00:12:52 -0800 Subject: [PATCH 017/102] Update metaformers.py --- timm/models/metaformers.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index e0189720..8098958c 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -521,14 +521,11 @@ class MetaFormerBlock(nn.Module): norm_layer=nn.LayerNorm, drop=0., drop_path=0., layer_scale_init_value=None, - res_scale_init_value=None, - downsample = nn.Identity() + res_scale_init_value=None ): super().__init__() - - self.downsample = downsample - + self.norm1 = norm_layer(dim) self.token_mixer = token_mixer(dim=dim, drop=drop) self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -546,7 +543,6 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() def forward(self, x): - x = self.downsample(x) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( @@ -653,18 +649,19 @@ class MetaFormer(nn.Module): stages = nn.ModuleList() # each stage consists of multiple metaformer blocks cur = 0 for i in range(num_stage): - stage = nn.Sequential(*[MetaFormerBlock( - dim=dims[i], - token_mixer=token_mixers[i], - mlp=mlps[i], - norm_layer=norm_layers[i], - drop_path=dp_rates[cur + j], - layer_scale_init_value=layer_scale_init_values[i], - res_scale_init_value=res_scale_init_values[i], - downsample = downsample_layers[i] - ) for j in range(depths[i])] + stage = nn.Sequential(OrderedDict[ + ('downsample', downsample_layers[i]), + ('blocks', nn.Sequential(*[MetaFormerBlock( + dim=dims[i], + token_mixer=token_mixers[i], + mlp=mlps[i], + norm_layer=norm_layers[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_values[i], + res_scale_init_value=res_scale_init_values[i] + ) for j in range(depths[i])]) + )] ) - stages.append(stage) cur += depths[i] From 7aa3459caf3dd393b28841c8443b3dfdd451f3bb Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 00:14:23 -0800 Subject: [PATCH 018/102] Update metaformers.py --- timm/models/metaformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 8098958c..21f46b5f 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -17,6 +17,7 @@ MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2, ConvFormer and CAFormer. Some implementations are modified from timm (https://github.com/rwightman/pytorch-image-models). """ +from collections import OrderedDict from functools import partial import torch import torch.nn as nn From 944fd549c40c51f9f0c30660f19c662e5f9518c6 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 09:37:49 -0800 Subject: [PATCH 019/102] Update metaformers.py --- timm/models/metaformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 21f46b5f..d12c8735 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -650,7 +650,7 @@ class MetaFormer(nn.Module): stages = nn.ModuleList() # each stage consists of multiple metaformer blocks cur = 0 for i in range(num_stage): - stage = nn.Sequential(OrderedDict[ + stage = nn.Sequential(OrderedDict([ ('downsample', downsample_layers[i]), ('blocks', nn.Sequential(*[MetaFormerBlock( dim=dims[i], @@ -661,7 +661,7 @@ class MetaFormer(nn.Module): layer_scale_init_value=layer_scale_init_values[i], res_scale_init_value=res_scale_init_values[i] ) for j in range(depths[i])]) - )] + )]) ) stages.append(stage) cur += depths[i] From 1b3318bf191b834b5518272ead9fbcc4ba4a1117 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 10:01:53 -0800 Subject: [PATCH 020/102] Update metaformers.py --- timm/models/metaformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index d12c8735..794fb62c 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -317,7 +317,7 @@ class RandomMixing(nn.Module): x = x.reshape(B, H, W, C) return x -''' + class LayerNormGeneral(nn.Module): r""" General LayerNorm for different situations. 
@@ -416,7 +416,7 @@ class LayerNormGeneral(nn.Module): if self.use_bias: x = x + self.bias return x - +''' class SepConv(nn.Module): r""" Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381. From 4cfecf8acb528f1ee2d349323a263eec1e1a4927 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 11:19:28 -0800 Subject: [PATCH 021/102] Update metaformers.py --- timm/models/metaformers.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 794fb62c..2ba67b0a 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -646,7 +646,10 @@ class MetaFormer(nn.Module): layer_scale_init_values = [layer_scale_init_values] * num_stage if not isinstance(res_scale_init_values, (list, tuple)): res_scale_init_values = [res_scale_init_values] * num_stage - + + self.grad_checkpointing = False + self.feature_info = [] + stages = nn.ModuleList() # each stage consists of multiple metaformer blocks cur = 0 for i in range(num_stage): @@ -665,6 +668,7 @@ class MetaFormer(nn.Module): ) stages.append(stage) cur += depths[i] + self.feature_info += [dict(num_chs=dims[stage_id], reduction=2, module=f'stages.{stage_id}')] self.stages = nn.Sequential(*stages) self.norm = output_norm(dims[-1]) @@ -687,17 +691,26 @@ class MetaFormer(nn.Module): @torch.jit.ignore def no_weight_decay(self): return {'norm'} - - def forward_features(self, x): - x = self.stages(x) + + def forward_head(self, x, pre_logits: bool = False): + if pre_logits: + return x + x = x.mean([1,2]) # TODO use adaptive pool instead of mean x = self.norm(x) # (B, H, W, C) -> (B, C) + x = self.head(x) + return x + + def forward_features(self, x): + x = self.stages(x) + + return x def forward(self, x): x = self.forward_features(x) - x = self.head(x) + x = forward_head(x) return x def checkpoint_filter_fn(state_dict, model): From 1145720bec9b0facc3b26b1421c288a9f71eb7ae Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 11:21:58 -0800 Subject: [PATCH 022/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 2ba67b0a..0e9a58cb 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -668,7 +668,7 @@ class MetaFormer(nn.Module): ) stages.append(stage) cur += depths[i] - self.feature_info += [dict(num_chs=dims[stage_id], reduction=2, module=f'stages.{stage_id}')] + self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) self.norm = output_norm(dims[-1]) From 0aa784027eaf3e56c9c7e40faa4e5b6df0d1fb4a Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 11:22:55 -0800 Subject: [PATCH 023/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 0e9a58cb..fe5622c6 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -710,7 +710,7 @@ class MetaFormer(nn.Module): def forward(self, x): x = self.forward_features(x) - x = forward_head(x) + x = self.forward_head(x) return x def checkpoint_filter_fn(state_dict, model): From bc37dec29b834ca3aae4ab649564b8382be39f19 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 11:51:19 -0800 Subject: [PATCH 024/102] Update metaformers.py --- timm/models/metaformers.py | 24 +++++++++++++++++++++++- 
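
For illustration only, not part of the diff: a hedged sketch of how the feature_info entries populated above are intended to be consumed once build_model_with_cfg wires up feature extraction; that it runs unchanged at this point in the series is an assumption.

    import timm

    m = timm.create_model('identityformer_s12', features_only=True, pretrained=False)
    print(m.feature_info.channels())   # expected [64, 128, 320, 512] for the s12 dims
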
1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index fe5622c6..9bb85c44 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -197,7 +197,7 @@ default_cfgs = { } cfgs_v2 = generate_default_cfgs(default_cfgs) - +''' class Downsampling(nn.Module): """ Downsampling implemented by a layer of convolution. @@ -221,7 +221,26 @@ class Downsampling(nn.Module): x = x.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C] x = self.post_norm(x) return x +''' +class Downsampling(nn.Module): + """ + Downsampling implemented by a layer of convolution. + """ + def __init__(self, in_channels, out_channels, + kernel_size, stride=1, padding=0, + pre_norm=None, post_norm=None): + super().__init__() + self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity() + self.pre_permute = pre_permute + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, + stride=stride, padding=padding) + self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() + def forward(self, x): + x = self.pre_norm(x) + x = self.conv(x) + x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x class Scale(nn.Module): """ @@ -544,6 +563,8 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() def forward(self, x): + B, C, H, W = x.shape + x = x.view(B, H, W, C) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( @@ -556,6 +577,7 @@ class MetaFormerBlock(nn.Module): self.mlp(self.norm2(x)) ) ) + x = x.view(B, C, H, W) return x From 4e682f1fda674f9dd5bc8afc662eec3831ada8da Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 11:52:13 -0800 Subject: [PATCH 025/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 9bb85c44..e694fd67 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -228,7 +228,7 @@ class Downsampling(nn.Module): """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, - pre_norm=None, post_norm=None): + pre_norm=None, post_norm=None, pre_permute = False): super().__init__() self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity() self.pre_permute = pre_permute From ddfdb543dc445b6d1e0885d53ccb1ba3fc15dd80 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 11:53:20 -0800 Subject: [PATCH 026/102] Update metaformers.py --- timm/models/metaformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index e694fd67..5afd653f 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -231,7 +231,6 @@ class Downsampling(nn.Module): pre_norm=None, post_norm=None, pre_permute = False): super().__init__() self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity() - self.pre_permute = pre_permute self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() From 8568bc7b6a62ad62cdf2ed791a32918edf08954b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 12:13:14 -0800 Subject: [PATCH 027/102] Update metaformers.py --- timm/models/metaformers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 5afd653f..5734ee04 100644 --- a/timm/models/metaformers.py +++ 
b/timm/models/metaformers.py @@ -236,9 +236,13 @@ class Downsampling(nn.Module): self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() def forward(self, x): + print(x.shape) x = self.pre_norm(x) + print(x.shape) x = self.conv(x) + print(x.shape) x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + print(x.shape) return x class Scale(nn.Module): From 464ae0fe50c4d468fe879ef289b21b0d5dd2d425 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 12:27:57 -0800 Subject: [PATCH 028/102] Update metaformers.py --- timm/models/metaformers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 5734ee04..fb950f61 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -197,7 +197,7 @@ default_cfgs = { } cfgs_v2 = generate_default_cfgs(default_cfgs) -''' + class Downsampling(nn.Module): """ Downsampling implemented by a layer of convolution. @@ -213,13 +213,13 @@ class Downsampling(nn.Module): self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() def forward(self, x): - x = self.pre_norm(x) if self.pre_permute: # if take [B, H, W, C] as input, permute it to [B, C, H, W] x = x.permute(0, 3, 1, 2) + x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x = self.conv(x) - x = x.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C] - x = self.post_norm(x) + x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x ''' class Downsampling(nn.Module): @@ -244,7 +244,7 @@ class Downsampling(nn.Module): x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) print(x.shape) return x - +''' class Scale(nn.Module): """ Scale vector by element multiplications. From bde9554604061e580f8dcd54c06ca2b811696e3c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 12:30:10 -0800 Subject: [PATCH 029/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index fb950f61..f909b079 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -596,7 +596,7 @@ DOWNSAMPLE_LAYERS_FOUR_STAGES = [partial(Downsampling, )] + \ [partial(Downsampling, kernel_size=3, stride=2, padding=1, - pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), pre_permute=True + pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), pre_permute=False )]*3 From cc8575e9d1da617040d28923450a9ee387b6fa11 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 12:33:57 -0800 Subject: [PATCH 030/102] Update metaformers.py --- timm/models/metaformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index f909b079..2e6a59af 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -735,6 +735,7 @@ class MetaFormer(nn.Module): def forward(self, x): x = self.forward_features(x) + print(x.shape) x = self.forward_head(x) return x From 3e3010b65e5e70b62422022f4a745217f9c66ff5 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 12:35:08 -0800 Subject: [PATCH 031/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 2e6a59af..e6c12d7f 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -721,7 +721,7 @@ class MetaFormer(nn.Module): if pre_logits: return x - x = x.mean([1,2]) # 
TODO use adaptive pool instead of mean + x = x.mean([-1,-2]) # TODO use adaptive pool instead of mean x = self.norm(x) # (B, H, W, C) -> (B, C) x = self.head(x) From 8dbba278b7397aeb49763e2811fcb4d2a78c8440 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Jan 2023 12:37:01 -0800 Subject: [PATCH 032/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index e6c12d7f..d3c3dde0 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -38,7 +38,7 @@ __all__ = ['MetaFormer'] def _cfg(url='', **kwargs): return { 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 1.0, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', **kwargs From a776d98d3f7f0784238a8ce81fe6acce78220f43 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 10:52:50 -0800 Subject: [PATCH 033/102] Update metaformers.py --- timm/models/metaformers.py | 37 +++++++++++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index d3c3dde0..3aa6ff1f 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -1,3 +1,13 @@ + + +""" + +MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2, +ConvFormer and CAFormer. + +original copyright below +""" + # Copyright 2022 Garena Online Private Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,12 +21,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -""" -MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2, -ConvFormer and CAFormer. -Some implementations are modified from timm (https://github.com/rwightman/pytorch-image-models). 
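
A minimal illustration, not in the patch, of the permute pattern the reworked Downsampling above relies on, assuming NCHW input and a channels-last norm such as LayerNormGeneral.

    import torch
    import torch.nn as nn

    x = torch.randn(2, 3, 224, 224)                            # (B, C, H, W)
    conv = nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=2)
    norm = nn.LayerNorm(64)                                     # stand-in for LayerNormGeneral
    y = conv(x)                                                 # (B, 64, 56, 56)
    y = norm(y.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)         # norm over C, back to NCHW
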
-""" from collections import OrderedDict from functools import partial import torch @@ -712,10 +716,27 @@ class MetaFormer(nn.Module): trunc_normal_(m.weight, std=.02) if m.bias is not None: nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + print("not implemented") @torch.jit.ignore - def no_weight_decay(self): - return {'norm'} + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes=0, global_pool=None): + if global_pool is not None: + self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() + if num_classes == 0: + self.head.norm = nn.Identity() + self.head.fc = nn.Identity() + else: + if not self.head_norm_first: + norm_layer = type(self.stem[-1]) # obtain type from stem norm + self.head.norm = norm_layer(self.num_features) + self.head.fc = nn.Linear(self.num_features, num_classes) def forward_head(self, x, pre_logits: bool = False): if pre_logits: From 95ec7cf01668a8ae7869b78ff011fa8085a45253 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:01:28 -0800 Subject: [PATCH 034/102] Update metaformers.py --- timm/models/metaformers.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 3aa6ff1f..ef43ded3 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -644,6 +644,9 @@ class MetaFormer(nn.Module): ): super().__init__() self.num_classes = num_classes + self.head_fn = head_fn + self.num_features = dims[-1] + self.head_dropout = head_dropout if not isinstance(depths, (list, tuple)): depths = [depths] # it means the model has only one stage @@ -705,9 +708,9 @@ class MetaFormer(nn.Module): if head_dropout > 0.0: - self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout) + self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) else: - self.head = head_fn(dims[-1], num_classes) + self.head = self.head_fn(self.num_featuers, self.num_classes) self.apply(self._init_weights) @@ -723,20 +726,18 @@ class MetaFormer(nn.Module): @torch.jit.ignore def get_classifier(self): - return self.head.fc + return self.head.fc2 def reset_classifier(self, num_classes=0, global_pool=None): - if global_pool is not None: - self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() + if num_classes == 0: - self.head.norm = nn.Identity() - self.head.fc = nn.Identity() + self.head= nn.Identity() + self.norm = nn.Identity() else: - if not self.head_norm_first: - norm_layer = type(self.stem[-1]) # obtain type from stem norm - self.head.norm = norm_layer(self.num_features) - self.head.fc = nn.Linear(self.num_features, num_classes) + if self.head_dropout > 0.0: + self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) + else: + self.head = self.head_fn(self.num_featuers, self.num_classes) def forward_head(self, x, pre_logits: bool = False): if pre_logits: @@ -756,7 +757,6 @@ class MetaFormer(nn.Module): def forward(self, x): x = self.forward_features(x) - print(x.shape) x = self.forward_head(x) return x From 61e8414ad04d64429a29dac7fca1f51542b68c04 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:02:43 -0800 Subject: [PATCH 035/102] Update metaformers.py --- timm/models/metaformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 
deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index ef43ded3..89651e02 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -710,7 +710,7 @@ class MetaFormer(nn.Module): if head_dropout > 0.0: self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) else: - self.head = self.head_fn(self.num_featuers, self.num_classes) + self.head = self.head_fn(self.num_features, self.num_classes) self.apply(self._init_weights) @@ -737,7 +737,7 @@ class MetaFormer(nn.Module): if self.head_dropout > 0.0: self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) else: - self.head = self.head_fn(self.num_featuers, self.num_classes) + self.head = self.head_fn(self.num_features, self.num_classes) def forward_head(self, x, pre_logits: bool = False): if pre_logits: From 199b4438849bc845e181f0924a5c5b5f7efe4a87 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:08:18 -0800 Subject: [PATCH 036/102] Update metaformers.py --- timm/models/metaformers.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 89651e02..921a33cc 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -26,7 +26,7 @@ from functools import partial import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_, DropPath +from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d from timm.layers.helpers import to_2tuple from ._builder import build_model_with_cfg from ._features import FeatureInfo @@ -640,6 +640,7 @@ class MetaFormer(nn.Module): res_scale_init_values=[None, None, 1.0, 1.0], output_norm=partial(nn.LayerNorm, eps=1e-6), head_fn=nn.Linear, + global_pool = 'avg', **kwargs, ): super().__init__() @@ -705,7 +706,7 @@ class MetaFormer(nn.Module): self.stages = nn.Sequential(*stages) self.norm = output_norm(dims[-1]) - + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if head_dropout > 0.0: self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) @@ -731,7 +732,7 @@ class MetaFormer(nn.Module): def reset_classifier(self, num_classes=0, global_pool=None): if num_classes == 0: - self.head= nn.Identity() + self.head = nn.Identity() self.norm = nn.Identity() else: if self.head_dropout > 0.0: @@ -743,7 +744,7 @@ class MetaFormer(nn.Module): if pre_logits: return x - x = x.mean([-1,-2]) # TODO use adaptive pool instead of mean + x = self.global_pool(x) x = self.norm(x) # (B, H, W, C) -> (B, C) x = self.head(x) From 0a9c2607a0eec7287fe3299f5580f3fdb9ebed86 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:10:48 -0800 Subject: [PATCH 037/102] Update metaformers.py --- timm/models/metaformers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 921a33cc..1a074944 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -731,6 +731,10 @@ class MetaFormer(nn.Module): def reset_classifier(self, num_classes=0, global_pool=None): + if global_pool is not None: + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + + if num_classes == 0: self.head = nn.Identity() self.norm = nn.Identity() From a73e464bc892e74892b1301b82a5964b565aa031 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:12:05 -0800 Subject: [PATCH 038/102] Update 
metaformers.py --- timm/models/metaformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 1a074944..92a8851f 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -749,6 +749,7 @@ class MetaFormer(nn.Module): return x x = self.global_pool(x) + x = x.flatten() x = self.norm(x) # (B, H, W, C) -> (B, C) x = self.head(x) From 7869dd67691e16e4c24d9be0a37e43d0a1e7922f Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:13:17 -0800 Subject: [PATCH 039/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 92a8851f..82e44579 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -749,7 +749,7 @@ class MetaFormer(nn.Module): return x x = self.global_pool(x) - x = x.flatten() + x = x.flatten(1) x = self.norm(x) # (B, H, W, C) -> (B, C) x = self.head(x) From d40d4c8c2aea37ec9257ad982cb745f606069598 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:16:05 -0800 Subject: [PATCH 040/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 82e44579..01857167 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -749,7 +749,7 @@ class MetaFormer(nn.Module): return x x = self.global_pool(x) - x = x.flatten(1) + x = x.squeeze() x = self.norm(x) # (B, H, W, C) -> (B, C) x = self.head(x) From f400e8a3c9fa1ccab61d7a44a0d57ab14392ed2b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:21:43 -0800 Subject: [PATCH 041/102] Update metaformers.py --- timm/models/metaformers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 01857167..1ca8ad7e 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -648,6 +648,7 @@ class MetaFormer(nn.Module): self.head_fn = head_fn self.num_features = dims[-1] self.head_dropout = head_dropout + self.output_norm = output_norm if not isinstance(depths, (list, tuple)): depths = [depths] # it means the model has only one stage @@ -704,7 +705,7 @@ class MetaFormer(nn.Module): self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) - self.norm = output_norm(dims[-1]) + self.norm = self.output_norm(self.num_features) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) @@ -739,6 +740,7 @@ class MetaFormer(nn.Module): self.head = nn.Identity() self.norm = nn.Identity() else: + self.norm = self.output_norm(self.num_features) if self.head_dropout > 0.0: self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) else: From 1b1b1d83b4bf242c225babd6c0573d177fed6aee Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:24:19 -0800 Subject: [PATCH 042/102] Update metaformers.py --- timm/models/metaformers.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 1ca8ad7e..c31185fa 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -706,13 +706,16 @@ class MetaFormer(nn.Module): self.stages = nn.Sequential(*stages) self.norm = self.output_norm(self.num_features) - + ''' self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if 
head_dropout > 0.0: self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) else: self.head = self.head_fn(self.num_features, self.num_classes) + + ''' + self.reset_classifier(self.num_classes, global_pool) self.apply(self._init_weights) @@ -742,9 +745,9 @@ class MetaFormer(nn.Module): else: self.norm = self.output_norm(self.num_features) if self.head_dropout > 0.0: - self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) + self.head = self.head_fn(self.num_features, num_classes, head_dropout=self.head_dropout) else: - self.head = self.head_fn(self.num_features, self.num_classes) + self.head = self.head_fn(self.num_features, num_classes) def forward_head(self, x, pre_logits: bool = False): if pre_logits: From 087a4513377d972e2d84eeb023d40ff9f00eb9fe Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:32:10 -0800 Subject: [PATCH 043/102] Update metaformers.py --- timm/models/metaformers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index c31185fa..83c757f2 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -44,7 +44,8 @@ def _cfg(url='', **kwargs): 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 1.0, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'classifier': 'head', 'first_conv': 'stages.0.downsample.conv', **kwargs } From 2a9f93c064597cd2c3bb2bfa66c544d2405fb1c4 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:36:18 -0800 Subject: [PATCH 044/102] Update metaformers.py --- timm/models/metaformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 83c757f2..015db95b 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -779,6 +779,7 @@ def checkpoint_filter_fn(state_dict, model): for k, v in state_dict.items(): k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) + k = re.sub(r'([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) out_dict[k] = v return out_dict From 5f2bebd7cafe8af74e458e077b0e4fc66a08c6b3 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 12 Jan 2023 11:37:33 -0800 Subject: [PATCH 045/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 015db95b..0020e789 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -779,7 +779,7 @@ def checkpoint_filter_fn(state_dict, model): for k, v in state_dict.items(): k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) - k = re.sub(r'([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) + k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) out_dict[k] = v return out_dict From 55d4eb78a2a3fa6f9b00ad98c76437d6a0b3fb95 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 07:47:52 -0800 Subject: [PATCH 046/102] merge with poolformer, initial version --- timm/models/metaformers.py | 1070 ++++++++++-------------------------- 1 file changed, 286 insertions(+), 784 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 0020e789..d96d6de9 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -26,7 +26,7 @@ from functools import partial 
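
As a side note, not from the patch: the two substitutions added to checkpoint_filter_fn in the previous two patches compose as sketched below; the keys are hypothetical but follow the original checkpoint layout.

    import re

    def _remap(k):
        k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
        k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
        return k

    print(_remap('downsample_layers.1.conv.weight'))   # stages.1.downsample.conv.weight
    print(_remap('stages.2.5.norm1.weight'))           # stages.2.blocks.5.norm1.weight
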
import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d +from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1 from timm.layers.helpers import to_2tuple from ._builder import build_model_with_cfg from ._features import FeatureInfo @@ -45,163 +45,177 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 1.0, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'classifier': 'head', 'first_conv': 'stages.0.downsample.conv', + 'classifier': 'head', 'first_conv': 'patch_embed.conv', **kwargs } -default_cfgs = { - 'identityformer_s12': _cfg( +cfgs_v2 = generate_default_cfgs({ + 'poolformerv1_s12.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar', + crop_pct=0.9), + 'poolformerv1_s24.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar', + crop_pct=0.9), + 'poolformerv1_s36.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s36.pth.tar', + crop_pct=0.9), + 'poolformerv1_m36.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m36.pth.tar', + crop_pct=0.95), + 'poolformerv1_m48.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar', + crop_pct=0.95), + + 'identityformer_s12.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'), - 'identityformer_s24': _cfg( + 'identityformer_s24.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'), - 'identityformer_s36': _cfg( + 'identityformer_s36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'), - 'identityformer_m36': _cfg( + 'identityformer_m36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'), - 'identityformer_m48': _cfg( + 'identityformer_m48.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'), - 'randformer_s12': _cfg( + 'randformer_s12.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'), - 'randformer_s24': _cfg( + 'randformer_s24.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'), - 'randformer_s36': _cfg( + 'randformer_s36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'), - 'randformer_m36': _cfg( + 'randformer_m36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'), - 'randformer_m48': _cfg( + 'randformer_m48.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'), - 'poolformerv2_s12': _cfg( + 'poolformerv2_s12.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'), - 'poolformerv2_s24': _cfg( + 'poolformerv2_s24.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'), - 'poolformerv2_s36': _cfg( + 'poolformerv2_s36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'), - 'poolformerv2_m36': _cfg( + 
'poolformerv2_m36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'), - 'poolformerv2_m48': _cfg( + 'poolformerv2_m48.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'), - 'convformer_s18': _cfg( + 'convformer_s18.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'), - 'convformer_s18_384': _cfg( + 'convformer_s18.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', input_size=(3, 384, 384)), - 'convformer_s18_in21ft1k': _cfg( + 'convformer_s18.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'), - 'convformer_s18_384_in21ft1k': _cfg( + 'convformer_s18.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'convformer_s18_in21k': _cfg( + 'convformer_s18.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth', num_classes=21841), - 'convformer_s36': _cfg( + 'convformer_s36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'), - 'convformer_s36_384': _cfg( + 'convformer_s36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth', input_size=(3, 384, 384)), - 'convformer_s36_in21ft1k': _cfg( + 'convformer_s36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'), - 'convformer_s36_384_in21ft1k': _cfg( + 'convformer_s36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'convformer_s36_in21k': _cfg( + 'convformer_s36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth', num_classes=21841), - 'convformer_m36': _cfg( + 'convformer_m36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'), - 'convformer_m36_384': _cfg( + 'convformer_m36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth', input_size=(3, 384, 384)), - 'convformer_m36_in21ft1k': _cfg( + 'convformer_m36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'), - 'convformer_m36_384_in21ft1k': _cfg( + 'convformer_m36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'convformer_m36_in21k': _cfg( + 'convformer_m36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth', num_classes=21841), - 'convformer_b36': _cfg( + 'convformer_b36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'), - 'convformer_b36_384': _cfg( + 'convformer_b36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth', input_size=(3, 384, 384)), - 'convformer_b36_in21ft1k': _cfg( + 'convformer_b36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'), - 'convformer_b36_384_in21ft1k': _cfg( + 'convformer_b36_384.sail_in22k_ft_in1k_384': _cfg( 
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'convformer_b36_in21k': _cfg( + 'convformer_b36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth', num_classes=21841), - 'caformer_s18': _cfg( + 'caformer_s18.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'), - 'caformer_s18_384': _cfg( + 'caformer_s18.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth', input_size=(3, 384, 384)), - 'caformer_s18_in21ft1k': _cfg( + 'caformer_s18.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'), - 'caformer_s18_384_in21ft1k': _cfg( + 'caformer_s18.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'caformer_s18_in21k': _cfg( + 'caformer_s18.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth', num_classes=21841), - 'caformer_s36': _cfg( + 'caformer_s36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'), - 'caformer_s36_384': _cfg( + 'caformer_s36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth', input_size=(3, 384, 384)), - 'caformer_s36_in21ft1k': _cfg( + 'caformer_s36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'), - 'caformer_s36_384_in21ft1k': _cfg( + 'caformer_s36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'caformer_s36_in21k': _cfg( + 'caformer_s36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth', num_classes=21841), - 'caformer_m36': _cfg( + 'caformer_m36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'), - 'caformer_m36_384': _cfg( + 'caformer_m36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth', input_size=(3, 384, 384)), - 'caformer_m36_in21ft1k': _cfg( + 'caformer_m36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'), - 'caformer_m36_384_in21ft1k': _cfg( + 'caformer_m36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'caformer_m36_in21k': _cfg( + 'caformer_m36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth', num_classes=21841), - 'caformer_b36': _cfg( + 'caformer_b36.sail_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'), - 'caformer_b36_384': _cfg( + 'caformer_b36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth', input_size=(3, 384, 384)), - 'caformer_b36_in21ft1k': _cfg( + 'caformer_b36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'), - 'caformer_b36_384_in21ft1k': _cfg( + 'caformer_b36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth', input_size=(3, 384, 384)), - 'caformer_b36_in21k': _cfg( + 
'caformer_b36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', num_classes=21841), -} - -cfgs_v2 = generate_default_cfgs(default_cfgs) +}) class Downsampling(nn.Module): """ @@ -290,6 +304,11 @@ class StarReLU(nn.Module): def forward(self, x): return self.scale * self.relu(x)**2 + self.bias +class Conv2dChannelsLast(nn.Conv2d): + def forward(self, x): + x = x.permute(0, 3, 1, 2) + return self._conv_forward(x, self.weight, self.bias).permute(0, 2, 3, 1) + class Attention(nn.Module): """ @@ -492,19 +511,28 @@ class Pooling(nn.Module): class Mlp(nn.Module): """ MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks. - Mostly copied from timm. + Modified from standard timm implementation """ - def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False): + def __init__( + self, + dim, + mlp_ratio=4, + out_features=None, + act_layer=StarReLU, + mlp_fn=nn.Linear, + drop=0., + bias=False + ): super().__init__() in_features = dim out_features = out_features or in_features hidden_features = int(mlp_ratio * in_features) drop_probs = to_2tuple(drop) - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.fc1 = mlp_fn(in_features, hidden_features, bias=bias) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.fc2 = mlp_fn(hidden_features, out_features, bias=bias) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): @@ -516,6 +544,7 @@ class Mlp(nn.Module): return x + class MlpHead(nn.Module): """ MLP classification head """ @@ -539,18 +568,24 @@ class MlpHead(nn.Module): return x + class MetaFormerBlock(nn.Module): """ Implementation of one MetaFormer block. """ - def __init__(self, dim, - token_mixer=nn.Identity, - mlp=Mlp, - norm_layer=nn.LayerNorm, - drop=0., drop_path=0., - layer_scale_init_value=None, - res_scale_init_value=None - ): + def __init__( + self, + dim, + token_mixer=nn.Identity, + mlp=Mlp, + mlp_fn=nn.Linear, + mlp_act=StarReLU, + mlp_bias=False, + norm_layer=nn.LayerNorm, + drop=0., drop_path=0., + layer_scale_init_value=None, + res_scale_init_value=None + ): super().__init__() @@ -563,7 +598,13 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() self.norm2 = norm_layer(dim) - self.mlp = mlp(dim=dim, drop=drop) + self.mlp = mlp( + dim=dim, + drop=drop, + mlp_fn=mlp_fn, + act_layer=mlp_act, + bias=mlp_bias + ) self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.layer_scale2 = Scale(dim=dim, init_value=layer_scale_init_value) \ if layer_scale_init_value else nn.Identity() @@ -573,6 +614,7 @@ class MetaFormerBlock(nn.Module): def forward(self, x): B, C, H, W = x.shape x = x.view(B, H, W, C) + print(x.shape) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( @@ -588,23 +630,6 @@ class MetaFormerBlock(nn.Module): x = x.view(B, C, H, W) return x - -r""" -downsampling (stem) for the first stage is a layer of conv with k7, s4 and p2 -downsamplings for the last 3 stages is a layer of conv with k3, s2 and p1 -DOWNSAMPLE_LAYERS_FOUR_STAGES format: [Downsampling, Downsampling, Downsampling, Downsampling] -use `partial` to specify some arguments -""" -DOWNSAMPLE_LAYERS_FOUR_STAGES = [partial(Downsampling, - kernel_size=7, stride=4, padding=2, - post_norm=partial(LayerNormGeneral, bias=False, eps=1e-6) - )] + \ - [partial(Downsampling, - kernel_size=3, stride=2, padding=1, - pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), pre_permute=False - )]*3 - - class MetaFormer(nn.Module): r""" MetaFormer A PyTorch impl of : `MetaFormer Baselines for Vision` - @@ -628,22 +653,29 @@ class MetaFormer(nn.Module): output_norm: norm before classifier head. Default: partial(nn.LayerNorm, eps=1e-6). head_fn: classification head. Default: nn.Linear. """ - def __init__(self, in_chans=3, num_classes=1000, - depths=[2, 2, 6, 2], - dims=[64, 128, 320, 512], - downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES, - token_mixers=nn.Identity, - mlps=Mlp, - norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False), - drop_path_rate=0., - head_dropout=0.0, - layer_scale_init_values=None, - res_scale_init_values=[None, None, 1.0, 1.0], - output_norm=partial(nn.LayerNorm, eps=1e-6), - head_fn=nn.Linear, - global_pool = 'avg', - **kwargs, - ): + def __init__( + self, + in_chans=3, + num_classes=1000, + depths=[2, 2, 6, 2], + dims=[64, 128, 320, 512], + #downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES, + downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), + token_mixers=nn.Identity, + mlps=Mlp, + mlp_fn=nn.Linear, + mlp_act = StarReLU, + mlp_bias=False, + norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False), + drop_path_rate=0., + head_dropout=0.0, + layer_scale_init_values=None, + res_scale_init_values=[None, None, 1.0, 1.0], + output_norm=partial(nn.LayerNorm, eps=1e-6), + head_fn=nn.Linear, + global_pool = 'avg', + **kwargs, + ): super().__init__() self.num_classes = num_classes self.head_fn = head_fn @@ -656,44 +688,66 @@ class MetaFormer(nn.Module): if not isinstance(dims, (list, tuple)): dims = [dims] - num_stage = len(depths) - self.num_stage = num_stage - + self.num_stages = len(depths) + ''' if not isinstance(downsample_layers, (list, tuple)): - downsample_layers = [downsample_layers] * num_stage + downsample_layers = [downsample_layers] * self.num_stages down_dims = [in_chans] + dims + downsample_layers = nn.ModuleList( - [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)] + [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(self.num_stages)] ) - + ''' if not isinstance(token_mixers, (list, tuple)): - token_mixers = [token_mixers] * num_stage + token_mixers = [token_mixers] * self.num_stages if not isinstance(mlps, (list, tuple)): - mlps = [mlps] * num_stage + mlps = [mlps] * self.num_stages if not isinstance(norm_layers, (list, tuple)): - norm_layers = [norm_layers] * num_stage + norm_layers = [norm_layers] * self.num_stages - dp_rates=[x.item() for x in torch.linspace(0, 
drop_path_rate, sum(depths))] if not isinstance(layer_scale_init_values, (list, tuple)): - layer_scale_init_values = [layer_scale_init_values] * num_stage + layer_scale_init_values = [layer_scale_init_values] * self.num_stages if not isinstance(res_scale_init_values, (list, tuple)): - res_scale_init_values = [res_scale_init_values] * num_stage + res_scale_init_values = [res_scale_init_values] * self.num_stages self.grad_checkpointing = False self.feature_info = [] + dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + + self.patch_embed = Downsampling( + in_chans, + dims[0], + kernel_size=7, + stride=4, + padding=2, + post_norm=downsample_norm + ) + stages = nn.ModuleList() # each stage consists of multiple metaformer blocks cur = 0 - for i in range(num_stage): + for i in range(self.num_stages): stage = nn.Sequential(OrderedDict([ - ('downsample', downsample_layers[i]), + ('downsample', nn.Identity() if i == 0 else Downsampling( + dims[i-1], + dims[i], + kernel_size=3, + stride=2, + padding=1, + pre_norm=downsample_norm, + pre_permute=False + )), ('blocks', nn.Sequential(*[MetaFormerBlock( dim=dims[i], token_mixer=token_mixers[i], mlp=mlps[i], + mlp_fn=mlp_fn, + mlp_act=mlp_act, + mlp_bias=mlp_bias, norm_layer=norm_layers[i], drop_path=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_values[i], @@ -762,6 +816,7 @@ class MetaFormer(nn.Module): return x def forward_features(self, x): + x = self.patch_embed(x) x = self.stages(x) @@ -773,13 +828,21 @@ class MetaFormer(nn.Module): return x def checkpoint_filter_fn(state_dict, model): - import re out_dict = {} for k, v in state_dict.items(): - + k = k.replace('proj', 'conv') + k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) + k = k.replace('network.1', 'downsample_layers.1') + k = k.replace('network.3', 'downsample_layers.2') + k = k.replace('network.5', 'downsample_layers.3') + k = k.replace('network.2', 'network.1') + k = k.replace('network.4', 'network.2') + k = k.replace('network.6', 'network.3') + k = k.replace('network', 'stages') k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) + k = k.replace('stages.0.downsample', 'patch_embed') out_dict[k] = v return out_dict @@ -797,23 +860,86 @@ def _create_metaformer(variant, pretrained=False, **kwargs): **kwargs) return model -''' + @register_model -def identityformer_s12(pretrained=False, **kwargs): - model = MetaFormer( +def poolformerv1_s12(pretrained=False, **kwargs): + model_kwargs = dict( depths=[2, 2, 6, 2], dims=[64, 128, 320, 512], - token_mixers=nn.Identity, - norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + downsample_norm=None, + token_mixers=Pooling, + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), + mlp_act=nn.GELU, + mlp_bias=True, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), + layer_scale_init_values=1e-5, + res_scale_init_values=None, **kwargs) - model.default_cfg = default_cfgs['identityformer_s12'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('poolformerv1_s12', pretrained=pretrained, **model_kwargs) -''' +@register_model +def poolformerv1_s24(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[4, 4, 12, 4], + dims=[64, 128, 320, 512], + downsample_norm=None, + token_mixers=Pooling, + 
mlp_fn=partial(nn.Conv2d, kernel_size=1), + mlp_act=nn.GELU, + mlp_bias=True, + norm_layers=GroupNorm1, + layer_scale_init_values=1e-5, + res_scale_init_values=None, + **kwargs) + return _create_metaformer('poolformerv1_s24', pretrained=pretrained, **model_kwargs) + +@register_model +def poolformerv1_s36(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[6, 6, 18, 6], + dims=[64, 128, 320, 512], + downsample_norm=None, + token_mixers=Pooling, + mlp_fn=partial(nn.Conv2d, kernel_size=1), + mlp_act=nn.GELU, + mlp_bias=True, + norm_layers=GroupNorm1, + layer_scale_init_values=1e-6, + res_scale_init_values=None, + **kwargs) + return _create_metaformer('poolformerv1_s36', pretrained=pretrained, **model_kwargs) + +@register_model +def poolformerv1_m36(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[6, 6, 18, 6], + dims=[96, 192, 384, 768], + downsample_norm=None, + token_mixers=Pooling, + mlp_fn=partial(nn.Conv2d, kernel_size=1), + mlp_act=nn.GELU, + mlp_bias=True, + norm_layers=GroupNorm1, + layer_scale_init_values=1e-6, + res_scale_init_values=None, + **kwargs) + return _create_metaformer('poolformerv1_m36', pretrained=pretrained, **model_kwargs) + +@register_model +def poolformerv1_m48(pretrained=False, **kwargs): + model_kwargs = dict( + depths=[8, 8, 24, 8], + dims=[96, 192, 384, 768], + downsample_norm=None, + token_mixers=Pooling, + mlp_fn=partial(nn.Conv2d, kernel_size=1), + mlp_act=nn.GELU, + mlp_bias=True, + norm_layers=GroupNorm1, + layer_scale_init_values=1e-6, + res_scale_init_values=None, + **kwargs) + return _create_metaformer('poolformerv1_m48', pretrained=pretrained, **model_kwargs) @register_model def identityformer_s12(pretrained=False, **kwargs): @@ -833,13 +959,7 @@ def identityformer_s24(pretrained=False, **kwargs): token_mixers=nn.Identity, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['identityformer_s24'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('identityformer_s24', pretrained=pretrained, **model_kwargs) @register_model def identityformer_s36(pretrained=False, **kwargs): @@ -849,13 +969,7 @@ def identityformer_s36(pretrained=False, **kwargs): token_mixers=nn.Identity, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['identityformer_s36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('identityformer_s36', pretrained=pretrained, **model_kwargs) @register_model def identityformer_m36(pretrained=False, **kwargs): @@ -865,13 +979,7 @@ def identityformer_m36(pretrained=False, **kwargs): token_mixers=nn.Identity, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['identityformer_m36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('identityformer_m36', pretrained=pretrained, **model_kwargs) @register_model def identityformer_m48(pretrained=False, **kwargs): @@ -881,13 +989,7 @@ def 
identityformer_m48(pretrained=False, **kwargs): token_mixers=nn.Identity, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['identityformer_m48'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('identityformer_m48', pretrained=pretrained, **model_kwargs) @register_model def randformer_s12(pretrained=False, **kwargs): @@ -897,13 +999,7 @@ def randformer_s12(pretrained=False, **kwargs): token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['randformer_s12'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('randformer_s12', pretrained=pretrained, **model_kwargs) @register_model def randformer_s24(pretrained=False, **kwargs): @@ -913,13 +1009,7 @@ def randformer_s24(pretrained=False, **kwargs): token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['randformer_s24'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('randformer_s24', pretrained=pretrained, **model_kwargs) @register_model def randformer_s36(pretrained=False, **kwargs): @@ -929,13 +1019,7 @@ def randformer_s36(pretrained=False, **kwargs): token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['randformer_s36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('randformer_s36', pretrained=pretrained, **model_kwargs) @register_model def randformer_m36(pretrained=False, **kwargs): @@ -945,13 +1029,7 @@ def randformer_m36(pretrained=False, **kwargs): token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['randformer_m36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('randformer_m36', pretrained=pretrained, **model_kwargs) @register_model def randformer_m48(pretrained=False, **kwargs): @@ -961,14 +1039,7 @@ def randformer_m48(pretrained=False, **kwargs): token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['randformer_m48'] - if pretrained: - state_dict = 
torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - + return _create_metaformer('randformer_m48', pretrained=pretrained, **model_kwargs) @register_model def poolformerv2_s12(pretrained=False, **kwargs): @@ -978,13 +1049,7 @@ def poolformerv2_s12(pretrained=False, **kwargs): token_mixers=Pooling, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['poolformerv2_s12'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('poolformerv2_s12', pretrained=pretrained, **model_kwargs) @register_model def poolformerv2_s24(pretrained=False, **kwargs): @@ -994,12 +1059,8 @@ def poolformerv2_s24(pretrained=False, **kwargs): token_mixers=Pooling, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['poolformerv2_s24'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('poolformerv2_s24', pretrained=pretrained, **model_kwargs) + @register_model @@ -1010,12 +1071,8 @@ def poolformerv2_s36(pretrained=False, **kwargs): token_mixers=Pooling, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['poolformerv2_s36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('poolformerv2_s36', pretrained=pretrained, **model_kwargs) + @register_model @@ -1026,12 +1083,7 @@ def poolformerv2_m36(pretrained=False, **kwargs): token_mixers=Pooling, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['poolformerv2_m36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('poolformerv2_m36', pretrained=pretrained, **model_kwargs) @register_model @@ -1042,92 +1094,21 @@ def poolformerv2_m48(pretrained=False, **kwargs): token_mixers=Pooling, norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), **kwargs) - model.default_cfg = default_cfgs['poolformerv2_m48'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('poolformerv2_m48', pretrained=pretrained, **model_kwargs) -@register_model -def convformer_s18(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s18'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - @register_model -def convformer_s18_384(pretrained=False, **kwargs): - 
model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s18_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_s18_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s18_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_s18_384_in21ft1k(pretrained=False, **kwargs): +def convformer_s18(pretrained=False, **kwargs): model = MetaFormer( depths=[3, 3, 9, 3], dims=[64, 128, 320, 512], token_mixers=SepConv, head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['convformer_s18_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('convformer_s18', pretrained=pretrained, **model_kwargs) -@register_model -def convformer_s18_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s18_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model @register_model @@ -1138,76 +1119,7 @@ def convformer_s36(pretrained=False, **kwargs): token_mixers=SepConv, head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['convformer_s36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_s36_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s36_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_s36_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s36_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_s36_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s36_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", 
check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_s36_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_s36_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('convformer_s36', pretrained=pretrained, **model_kwargs) @register_model @@ -1218,76 +1130,8 @@ def convformer_m36(pretrained=False, **kwargs): token_mixers=SepConv, head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['convformer_m36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('convformer_m36', pretrained=pretrained, **model_kwargs) -@register_model -def convformer_m36_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_m36_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_m36_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_m36_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_m36_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_m36_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_m36_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_m36_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model @register_model @@ -1298,76 +1142,9 @@ def convformer_b36(pretrained=False, **kwargs): token_mixers=SepConv, head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['convformer_b36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_b36_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_b36_384'] - if 
pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_b36_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_b36_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('convformer_b36', pretrained=pretrained, **model_kwargs) -@register_model -def convformer_b36_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_b36_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def convformer_b36_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=SepConv, - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['convformer_b36_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model @register_model @@ -1378,76 +1155,8 @@ def caformer_s18(pretrained=False, **kwargs): token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['caformer_s18'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s18_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s18_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s18_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s18_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - + return _create_metaformer('caformer_s18', pretrained=pretrained, **model_kwargs) -@register_model -def caformer_s18_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s18_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - 
model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s18_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 3, 9, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s18_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model @register_model @@ -1458,76 +1167,7 @@ def caformer_s36(pretrained=False, **kwargs): token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['caformer_s36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s36_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s36_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s36_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s36_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s36_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s36_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_s36_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[64, 128, 320, 512], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_s36_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('caformer_s36', pretrained=pretrained, **model_kwargs) @register_model @@ -1538,76 +1178,7 @@ def caformer_m36(pretrained=False, **kwargs): token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['caformer_m36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_m36_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=[SepConv, 
SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_m36_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_m36_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_m36_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_m36_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_m36_384_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_m364_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[96, 192, 384, 576], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_m364_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model + return _create_metaformer('caformer_m36', pretrained=pretrained, **model_kwargs) @register_model @@ -1618,73 +1189,4 @@ def caformer_b36(pretrained=False, **kwargs): token_mixers=[SepConv, SepConv, Attention, Attention], head_fn=MlpHead, **kwargs) - model.default_cfg = default_cfgs['caformer_b36'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_b36_384(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_b36_384'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_b36_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_b36_in21ft1k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_b36_384_in21ft1k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_b36_384_in21ft1k'] - if 
pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model - - -@register_model -def caformer_b36_in21k(pretrained=False, **kwargs): - model = MetaFormer( - depths=[3, 12, 18, 3], - dims=[128, 256, 512, 768], - token_mixers=[SepConv, SepConv, Attention, Attention], - head_fn=MlpHead, - **kwargs) - model.default_cfg = default_cfgs['caformer_b36_in21k'] - if pretrained: - state_dict = torch.hub.load_state_dict_from_url( - url= model.default_cfg['url'], map_location="cpu", check_hash=True) - model.load_state_dict(state_dict) - return model \ No newline at end of file + return _create_metaformer('caformer_b36', pretrained=pretrained, **model_kwargs) From eaf54b66afb5365e9de79facbdb6c507328696d3 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 08:01:43 -0800 Subject: [PATCH 047/102] Update metaformers.py --- timm/models/metaformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index d96d6de9..679d7ed6 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -614,7 +614,6 @@ class MetaFormerBlock(nn.Module): def forward(self, x): B, C, H, W = x.shape x = x.view(B, H, W, C) - print(x.shape) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( From 5de2544a8025fbf2f71d4a33b384e2a59de32c76 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 10:49:34 -0800 Subject: [PATCH 048/102] Update metaformers.py --- timm/models/metaformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index d96d6de9..4e63c1b8 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -536,6 +536,7 @@ class Mlp(nn.Module): self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): + print(x) x = self.fc1(x) x = self.act(x) x = self.drop1(x) From bfa6f0962d5b1bfcfe166bf829da2bf64c0cdbbd Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 10:56:03 -0800 Subject: [PATCH 049/102] Update metaformers.py --- timm/models/metaformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 5b7f7ce3..679d7ed6 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -536,7 +536,6 @@ class Mlp(nn.Module): self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): - print(x) x = self.fc1(x) x = self.act(x) x = self.drop1(x) From 09d1ea628dab41011ebe9bf72f699e60ddf9d295 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 10:58:07 -0800 Subject: [PATCH 050/102] Update metaformers.py --- timm/models/metaformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 679d7ed6..ce5e99d3 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -363,7 +363,7 @@ class RandomMixing(nn.Module): x = x.reshape(B, H, W, C) return x - +''' class LayerNormGeneral(nn.Module): r""" General LayerNorm for different situations. @@ -462,7 +462,7 @@ class LayerNormGeneral(nn.Module): if self.use_bias: x = x + self.bias return x -''' + class SepConv(nn.Module): r""" Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381. 
From 473403d9059a9153722b53d1f9d3e72e7720b93c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 11:00:25 -0800 Subject: [PATCH 051/102] Update metaformers.py --- timm/models/metaformers.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index ce5e99d3..7632c69a 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -216,7 +216,7 @@ cfgs_v2 = generate_default_cfgs({ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', num_classes=21841), }) - +''' class Downsampling(nn.Module): """ Downsampling implemented by a layer of convolution. @@ -255,15 +255,15 @@ class Downsampling(nn.Module): self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() def forward(self, x): - print(x.shape) + #print(x.shape) x = self.pre_norm(x) - print(x.shape) + #print(x.shape) x = self.conv(x) - print(x.shape) + #print(x.shape) x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - print(x.shape) + #print(x.shape) return x -''' + class Scale(nn.Module): """ Scale vector by element multiplications. @@ -363,7 +363,7 @@ class RandomMixing(nn.Module): x = x.reshape(B, H, W, C) return x -''' + class LayerNormGeneral(nn.Module): r""" General LayerNorm for different situations. @@ -462,7 +462,7 @@ class LayerNormGeneral(nn.Module): if self.use_bias: x = x + self.bias return x - +''' class SepConv(nn.Module): r""" Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381. From 13876ada4c9dfe099bd5388e9660bf7bbdc82ef2 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 11:05:30 -0800 Subject: [PATCH 052/102] Update metaformers.py --- timm/models/metaformers.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 7632c69a..a5fffab9 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -216,7 +216,7 @@ cfgs_v2 = generate_default_cfgs({ url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', num_classes=21841), }) -''' + class Downsampling(nn.Module): """ Downsampling implemented by a layer of convolution. @@ -255,15 +255,15 @@ class Downsampling(nn.Module): self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() def forward(self, x): - #print(x.shape) + print(x.shape) x = self.pre_norm(x) - #print(x.shape) + print(x.shape) x = self.conv(x) - #print(x.shape) + print(x.shape) x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - #print(x.shape) + print(x.shape) return x - +''' class Scale(nn.Module): """ Scale vector by element multiplications. 
@@ -612,8 +612,9 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() def forward(self, x): - B, C, H, W = x.shape - x = x.view(B, H, W, C) + #B, C, H, W = x.shape + #x = x.view(B, H, W, C) + x = x.permute(0, 2, 3, 1) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( @@ -626,7 +627,8 @@ class MetaFormerBlock(nn.Module): self.mlp(self.norm2(x)) ) ) - x = x.view(B, C, H, W) + #x = x.view(B, C, H, W) + x = x.permute(0, 3, 1, 2) return x class MetaFormer(nn.Module): From 9bf9d163bc9231212ff1a3c4c520e7fd949cefaa Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 11:08:25 -0800 Subject: [PATCH 053/102] Update metaformers.py --- timm/models/metaformers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index a5fffab9..b7435128 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -832,6 +832,7 @@ def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): + ''' k = k.replace('proj', 'conv') k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) k = k.replace('network.1', 'downsample_layers.1') @@ -841,6 +842,7 @@ def checkpoint_filter_fn(state_dict, model): k = k.replace('network.4', 'network.2') k = k.replace('network.6', 'network.3') k = k.replace('network', 'stages') + ''' k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = k.replace('stages.0.downsample', 'patch_embed') From bee9576f03f483a269e917da075d96c92faf35f1 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 11:11:19 -0800 Subject: [PATCH 054/102] Update metaformers.py --- timm/models/metaformers.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index b7435128..d5bf9692 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -809,11 +809,12 @@ class MetaFormer(nn.Module): if pre_logits: return x - x = self.global_pool(x) - x = x.squeeze() - x = self.norm(x) + #x = self.global_pool(x) + #x = x.squeeze() + #x = self.norm(x) # (B, H, W, C) -> (B, C) - x = self.head(x) + #x = self.head(x) + x=self.head(self.norm(x.mean([1, 2]))) return x def forward_features(self, x): From 808f4a7ebd5ffa0e951e22e9636f1bbb6a01730c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 11:13:14 -0800 Subject: [PATCH 055/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index d5bf9692..1356c769 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -814,7 +814,7 @@ class MetaFormer(nn.Module): #x = self.norm(x) # (B, H, W, C) -> (B, C) #x = self.head(x) - x=self.head(self.norm(x.mean([1, 2]))) + x=self.head(self.norm(x.mean([2, 3]))) return x def forward_features(self, x): From 6b6510f30ab8d153f6fb47d1a3af6f21272d66f1 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 11:49:10 -0800 Subject: [PATCH 056/102] Update metaformers.py --- timm/models/metaformers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 1356c769..5608b368 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -819,7 +819,10 @@ class MetaFormer(nn.Module): def forward_features(self, x): x = self.patch_embed(x) - x = self.stages(x) + #x = 
self.stages(x) + for i, stage in enumerate(self.stages): + x=stage(x) + print(x) return x From 2fef9006d72824789b733dd8dbc20b21b79add0b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 11:56:52 -0800 Subject: [PATCH 057/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 5608b368..a074d3fe 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -822,7 +822,7 @@ class MetaFormer(nn.Module): #x = self.stages(x) for i, stage in enumerate(self.stages): x=stage(x) - print(x) + print(x[0]) return x From cf28c57fda98ffc26d462f92c07532733ea97645 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 11:58:27 -0800 Subject: [PATCH 058/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index a074d3fe..9ae7701f 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -822,7 +822,7 @@ class MetaFormer(nn.Module): #x = self.stages(x) for i, stage in enumerate(self.stages): x=stage(x) - print(x[0]) + print(x[0][0][0][0]) return x From 3f0b07536764c47af42290a218a15097b5cec1a0 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:00:17 -0800 Subject: [PATCH 059/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 9ae7701f..27964167 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -822,7 +822,7 @@ class MetaFormer(nn.Module): #x = self.stages(x) for i, stage in enumerate(self.stages): x=stage(x) - print(x[0][0][0][0]) + #print(x[0][0][0][0]) return x From 4ed934e00068be576881a87d5218b355ca01a6be Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:14:45 -0800 Subject: [PATCH 060/102] Update metaformers.py --- timm/models/metaformers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 27964167..dc83c185 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -819,11 +819,13 @@ class MetaFormer(nn.Module): def forward_features(self, x): x = self.patch_embed(x) + print('timm') #x = self.stages(x) + ''' for i, stage in enumerate(self.stages): x=stage(x) #print(x[0][0][0][0]) - + ''' return x From 32bede4e279ee1027a63d5b55591153e53c209ec Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:18:20 -0800 Subject: [PATCH 061/102] Update metaformers.py --- timm/models/metaformers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index dc83c185..87018b47 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -809,12 +809,12 @@ class MetaFormer(nn.Module): if pre_logits: return x - #x = self.global_pool(x) - #x = x.squeeze() - #x = self.norm(x) + x = self.global_pool(x) + x = x.squeeze() + x = self.norm(x) # (B, H, W, C) -> (B, C) - #x = self.head(x) - x=self.head(self.norm(x.mean([2, 3]))) + x = self.head(x) + #x=self.head(self.norm(x.mean([2, 3]))) return x def forward_features(self, x): From 2209d0830eed9b87149df4a1f25f65da46d47f4b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:20:54 -0800 Subject: [PATCH 062/102] Update metaformers.py --- timm/models/metaformers.py | 
22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 87018b47..733d05c9 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -536,6 +536,7 @@ class Mlp(nn.Module): self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): + print(x) x = self.fc1(x) x = self.act(x) x = self.drop1(x) @@ -612,9 +613,8 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() def forward(self, x): - #B, C, H, W = x.shape - #x = x.view(B, H, W, C) - x = x.permute(0, 2, 3, 1) + B, C, H, W = x.shape + x = x.view(B, H, W, C) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( @@ -627,8 +627,7 @@ class MetaFormerBlock(nn.Module): self.mlp(self.norm2(x)) ) ) - #x = x.view(B, C, H, W) - x = x.permute(0, 3, 1, 2) + x = x.view(B, C, H, W) return x class MetaFormer(nn.Module): @@ -814,31 +813,25 @@ class MetaFormer(nn.Module): x = self.norm(x) # (B, H, W, C) -> (B, C) x = self.head(x) - #x=self.head(self.norm(x.mean([2, 3]))) return x def forward_features(self, x): x = self.patch_embed(x) - print('timm') - #x = self.stages(x) - ''' - for i, stage in enumerate(self.stages): - x=stage(x) - #print(x[0][0][0][0]) - ''' + x = self.stages(x) + return x def forward(self, x): x = self.forward_features(x) x = self.forward_head(x) + print('timm') return x def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): - ''' k = k.replace('proj', 'conv') k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) k = k.replace('network.1', 'downsample_layers.1') @@ -848,7 +841,6 @@ def checkpoint_filter_fn(state_dict, model): k = k.replace('network.4', 'network.2') k = k.replace('network.6', 'network.3') k = k.replace('network', 'stages') - ''' k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = k.replace('stages.0.downsample', 'patch_embed') From 1d882eb494caf5b5a3a1e9c7442d883e1707998e Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:22:29 -0800 Subject: [PATCH 063/102] Update metaformers.py --- timm/models/metaformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 733d05c9..a108afda 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -536,7 +536,6 @@ class Mlp(nn.Module): self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): - print(x) x = self.fc1(x) x = self.act(x) x = self.drop1(x) From 2916f37f8d862cd4fcbc86c858b23894daec3339 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:27:12 -0800 Subject: [PATCH 064/102] Update metaformers.py --- timm/models/metaformers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index a108afda..6f09dd74 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -816,8 +816,10 @@ class MetaFormer(nn.Module): def forward_features(self, x): x = self.patch_embed(x) - x = self.stages(x) - + #x = self.stages(x) + for i, stage in enumerate(self.stages): + x = stage(x) + print(x[0][0][0][0]) return x From e4f89b2e25064fb4368a83cc5bc2d916f6f3a2cf Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:33:31 -0800 Subject: [PATCH 065/102] Revert "Update metaformers.py" This reverts commit 2916f37f8d862cd4fcbc86c858b23894daec3339. 
--- timm/models/metaformers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 6f09dd74..a108afda 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -816,10 +816,8 @@ class MetaFormer(nn.Module): def forward_features(self, x): x = self.patch_embed(x) - #x = self.stages(x) - for i, stage in enumerate(self.stages): - x = stage(x) - print(x[0][0][0][0]) + x = self.stages(x) + return x From 4f6bbbfb0e51260de5c68470aff7c69bc34e7b19 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:33:36 -0800 Subject: [PATCH 066/102] Revert "Update metaformers.py" This reverts commit 1d882eb494caf5b5a3a1e9c7442d883e1707998e. --- timm/models/metaformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index a108afda..733d05c9 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -536,6 +536,7 @@ class Mlp(nn.Module): self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): + print(x) x = self.fc1(x) x = self.act(x) x = self.drop1(x) From 2912d8c477c991ae6cf4d241e8feea613db6c7a2 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:33:44 -0800 Subject: [PATCH 067/102] Revert "Update metaformers.py" This reverts commit 2209d0830eed9b87149df4a1f25f65da46d47f4b. --- timm/models/metaformers.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 733d05c9..87018b47 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -536,7 +536,6 @@ class Mlp(nn.Module): self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): - print(x) x = self.fc1(x) x = self.act(x) x = self.drop1(x) @@ -613,8 +612,9 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() def forward(self, x): - B, C, H, W = x.shape - x = x.view(B, H, W, C) + #B, C, H, W = x.shape + #x = x.view(B, H, W, C) + x = x.permute(0, 2, 3, 1) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( @@ -627,7 +627,8 @@ class MetaFormerBlock(nn.Module): self.mlp(self.norm2(x)) ) ) - x = x.view(B, C, H, W) + #x = x.view(B, C, H, W) + x = x.permute(0, 3, 1, 2) return x class MetaFormer(nn.Module): @@ -813,25 +814,31 @@ class MetaFormer(nn.Module): x = self.norm(x) # (B, H, W, C) -> (B, C) x = self.head(x) + #x=self.head(self.norm(x.mean([2, 3]))) return x def forward_features(self, x): x = self.patch_embed(x) - x = self.stages(x) - + print('timm') + #x = self.stages(x) + ''' + for i, stage in enumerate(self.stages): + x=stage(x) + #print(x[0][0][0][0]) + ''' return x def forward(self, x): x = self.forward_features(x) x = self.forward_head(x) - print('timm') return x def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): + ''' k = k.replace('proj', 'conv') k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) k = k.replace('network.1', 'downsample_layers.1') @@ -841,6 +848,7 @@ def checkpoint_filter_fn(state_dict, model): k = k.replace('network.4', 'network.2') k = k.replace('network.6', 'network.3') k = k.replace('network', 'stages') + ''' k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = k.replace('stages.0.downsample', 'patch_embed') From a1882204eab13f8bcb85ef226083e71a910b3605 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:33:51 -0800 Subject: 
[PATCH 068/102] Revert "Update metaformers.py" This reverts commit 32bede4e279ee1027a63d5b55591153e53c209ec. --- timm/models/metaformers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 87018b47..dc83c185 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -809,12 +809,12 @@ class MetaFormer(nn.Module): if pre_logits: return x - x = self.global_pool(x) - x = x.squeeze() - x = self.norm(x) + #x = self.global_pool(x) + #x = x.squeeze() + #x = self.norm(x) # (B, H, W, C) -> (B, C) - x = self.head(x) - #x=self.head(self.norm(x.mean([2, 3]))) + #x = self.head(x) + x=self.head(self.norm(x.mean([2, 3]))) return x def forward_features(self, x): From f985f35b9193a70fdef5ecb8cb9d8d9a5bdfcb2c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:33:56 -0800 Subject: [PATCH 069/102] Revert "Update metaformers.py" This reverts commit 4ed934e00068be576881a87d5218b355ca01a6be. --- timm/models/metaformers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index dc83c185..27964167 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -819,13 +819,11 @@ class MetaFormer(nn.Module): def forward_features(self, x): x = self.patch_embed(x) - print('timm') #x = self.stages(x) - ''' for i, stage in enumerate(self.stages): x=stage(x) #print(x[0][0][0][0]) - ''' + return x From 869e01cec6b27285743d993106f741917ab0efe0 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:34:01 -0800 Subject: [PATCH 070/102] Revert "Update metaformers.py" This reverts commit 3f0b07536764c47af42290a218a15097b5cec1a0. --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 27964167..9ae7701f 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -822,7 +822,7 @@ class MetaFormer(nn.Module): #x = self.stages(x) for i, stage in enumerate(self.stages): x=stage(x) - #print(x[0][0][0][0]) + print(x[0][0][0][0]) return x From 47b718a1c84d57455aa6f81f8a71152f4f6157a0 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 12:34:29 -0800 Subject: [PATCH 071/102] Revert "Update metaformers.py" This reverts commit 2fef9006d72824789b733dd8dbc20b21b79add0b. 
--- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 9ae7701f..5608b368 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -822,7 +822,7 @@ class MetaFormer(nn.Module): #x = self.stages(x) for i, stage in enumerate(self.stages): x=stage(x) - print(x[0][0][0][0]) + print(x) return x From 49bf08ed22cc3ea8556a1e208d555eb5a1d3280c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 13:22:16 -0800 Subject: [PATCH 072/102] Update metaformers.py --- timm/models/metaformers.py | 414 +++++++++++++++++++------------------ 1 file changed, 209 insertions(+), 205 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 679d7ed6..361936d4 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -39,184 +39,6 @@ from ._registry import register_model __all__ = ['MetaFormer'] -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 1.0, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'classifier': 'head', 'first_conv': 'patch_embed.conv', - **kwargs - } - - -cfgs_v2 = generate_default_cfgs({ - 'poolformerv1_s12.sail_in1k': _cfg( - url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar', - crop_pct=0.9), - 'poolformerv1_s24.sail_in1k': _cfg( - url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar', - crop_pct=0.9), - 'poolformerv1_s36.sail_in1k': _cfg( - url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s36.pth.tar', - crop_pct=0.9), - 'poolformerv1_m36.sail_in1k': _cfg( - url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m36.pth.tar', - crop_pct=0.95), - 'poolformerv1_m48.sail_in1k': _cfg( - url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar', - crop_pct=0.95), - - 'identityformer_s12.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'), - 'identityformer_s24.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'), - 'identityformer_s36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'), - 'identityformer_m36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'), - 'identityformer_m48.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'), - - - 'randformer_s12.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'), - 'randformer_s24.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'), - 'randformer_s36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'), - 'randformer_m36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'), - 'randformer_m48.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'), - - 'poolformerv2_s12.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'), - 'poolformerv2_s24.sail_in1k': _cfg( - 
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'), - 'poolformerv2_s36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'), - 'poolformerv2_m36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'), - 'poolformerv2_m48.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'), - - - - 'convformer_s18.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'), - 'convformer_s18.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', - input_size=(3, 384, 384)), - 'convformer_s18.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'), - 'convformer_s18.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'convformer_s18.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth', - num_classes=21841), - - 'convformer_s36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'), - 'convformer_s36.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth', - input_size=(3, 384, 384)), - 'convformer_s36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'), - 'convformer_s36.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'convformer_s36.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth', - num_classes=21841), - - 'convformer_m36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'), - 'convformer_m36.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth', - input_size=(3, 384, 384)), - 'convformer_m36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'), - 'convformer_m36.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'convformer_m36.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth', - num_classes=21841), - - 'convformer_b36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'), - 'convformer_b36.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth', - input_size=(3, 384, 384)), - 'convformer_b36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'), - 'convformer_b36_384.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'convformer_b36.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth', - num_classes=21841), - - - 'caformer_s18.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'), - 
'caformer_s18.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth', - input_size=(3, 384, 384)), - 'caformer_s18.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'), - 'caformer_s18.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'caformer_s18.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth', - num_classes=21841), - - 'caformer_s36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'), - 'caformer_s36.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth', - input_size=(3, 384, 384)), - 'caformer_s36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'), - 'caformer_s36.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'caformer_s36.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth', - num_classes=21841), - - 'caformer_m36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'), - 'caformer_m36.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth', - input_size=(3, 384, 384)), - 'caformer_m36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'), - 'caformer_m36.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'caformer_m36.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth', - num_classes=21841), - - 'caformer_b36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'), - 'caformer_b36.sail_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth', - input_size=(3, 384, 384)), - 'caformer_b36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'), - 'caformer_b36.sail_in22k_ft_in1k_384': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth', - input_size=(3, 384, 384)), - 'caformer_b36.sail_in22k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', - num_classes=21841), -}) - class Downsampling(nn.Module): """ Downsampling implemented by a layer of convolution. 
@@ -238,7 +60,10 @@ class Downsampling(nn.Module): x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) x = self.conv(x) + + x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + print(x[0][0][0][0]) return x ''' class Downsampling(nn.Module): @@ -760,16 +585,14 @@ class MetaFormer(nn.Module): self.stages = nn.Sequential(*stages) self.norm = self.output_norm(self.num_features) - ''' + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) if head_dropout > 0.0: self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) else: self.head = self.head_fn(self.num_features, self.num_classes) - - ''' - self.reset_classifier(self.num_classes, global_pool) + self.apply(self._init_weights) @@ -816,7 +639,9 @@ class MetaFormer(nn.Module): def forward_features(self, x): x = self.patch_embed(x) - x = self.stages(x) + #x = self.stages(x) + for i, stage in enumerate(self.stages): + x = stage(x) return x @@ -860,6 +685,185 @@ def _create_metaformer(variant, pretrained=False, **kwargs): return model + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 1.0, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'classifier': 'head', 'first_conv': 'patch_embed.conv', + **kwargs + } + +default_cfgs = generate_default_cfgs({ + 'poolformerv1_s12.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s12.pth.tar', + crop_pct=0.9), + 'poolformerv1_s24.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s24.pth.tar', + crop_pct=0.9), + 'poolformerv1_s36.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_s36.pth.tar', + crop_pct=0.9), + 'poolformerv1_m36.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m36.pth.tar', + crop_pct=0.95), + 'poolformerv1_m48.sail_in1k': _cfg( + url='https://github.com/sail-sg/poolformer/releases/download/v1.0/poolformer_m48.pth.tar', + crop_pct=0.95), + + 'identityformer_s12.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'), + 'identityformer_s24.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'), + 'identityformer_s36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'), + 'identityformer_m36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'), + 'identityformer_m48.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'), + + + 'randformer_s12.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'), + 'randformer_s24.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'), + 'randformer_s36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'), + 'randformer_m36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'), + 'randformer_m48.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'), + + 'poolformerv2_s12.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'), + 
'poolformerv2_s24.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'), + 'poolformerv2_s36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'), + 'poolformerv2_m36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'), + 'poolformerv2_m48.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'), + + + + 'convformer_s18.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'), + 'convformer_s18.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', + input_size=(3, 384, 384)), + 'convformer_s18.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'), + 'convformer_s18.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_s18.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth', + num_classes=21841), + + 'convformer_s36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'), + 'convformer_s36.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth', + input_size=(3, 384, 384)), + 'convformer_s36.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'), + 'convformer_s36.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_s36.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth', + num_classes=21841), + + 'convformer_m36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'), + 'convformer_m36.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth', + input_size=(3, 384, 384)), + 'convformer_m36.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'), + 'convformer_m36.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_m36.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth', + num_classes=21841), + + 'convformer_b36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'), + 'convformer_b36.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth', + input_size=(3, 384, 384)), + 'convformer_b36.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'), + 'convformer_b36_384.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_b36.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth', + num_classes=21841), + + + 'caformer_s18.sail_in1k': _cfg( + 
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'), + 'caformer_s18.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth', + input_size=(3, 384, 384)), + 'caformer_s18.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'), + 'caformer_s18.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_s18.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth', + num_classes=21841), + + 'caformer_s36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'), + 'caformer_s36.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth', + input_size=(3, 384, 384)), + 'caformer_s36.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'), + 'caformer_s36.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_s36.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth', + num_classes=21841), + + 'caformer_m36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'), + 'caformer_m36.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth', + input_size=(3, 384, 384)), + 'caformer_m36.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'), + 'caformer_m36.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_m36.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth', + num_classes=21841), + + 'caformer_b36.sail_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'), + 'caformer_b36.sail_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth', + input_size=(3, 384, 384)), + 'caformer_b36.sail_in22k_ft_in1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'), + 'caformer_b36.sail_in22k_ft_in1k_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_b36.sail_in22k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', + num_classes=21841), +}) + + @register_model def poolformerv1_s12(pretrained=False, **kwargs): model_kwargs = dict( @@ -952,7 +956,7 @@ def identityformer_s12(pretrained=False, **kwargs): @register_model def identityformer_s24(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[4, 4, 12, 4], dims=[64, 128, 320, 512], token_mixers=nn.Identity, @@ -962,7 +966,7 @@ def identityformer_s24(pretrained=False, **kwargs): @register_model def identityformer_s36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[6, 6, 18, 6], dims=[64, 128, 320, 512], token_mixers=nn.Identity, @@ -972,7 +976,7 @@ def identityformer_s36(pretrained=False, **kwargs): @register_model def identityformer_m36(pretrained=False, **kwargs): - model 
= MetaFormer( + model_kwargs = dict( depths=[6, 6, 18, 6], dims=[96, 192, 384, 768], token_mixers=nn.Identity, @@ -982,7 +986,7 @@ def identityformer_m36(pretrained=False, **kwargs): @register_model def identityformer_m48(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[8, 8, 24, 8], dims=[96, 192, 384, 768], token_mixers=nn.Identity, @@ -992,7 +996,7 @@ def identityformer_m48(pretrained=False, **kwargs): @register_model def randformer_s12(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[2, 2, 6, 2], dims=[64, 128, 320, 512], token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], @@ -1002,7 +1006,7 @@ def randformer_s12(pretrained=False, **kwargs): @register_model def randformer_s24(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[4, 4, 12, 4], dims=[64, 128, 320, 512], token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], @@ -1012,7 +1016,7 @@ def randformer_s24(pretrained=False, **kwargs): @register_model def randformer_s36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[6, 6, 18, 6], dims=[64, 128, 320, 512], token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], @@ -1022,7 +1026,7 @@ def randformer_s36(pretrained=False, **kwargs): @register_model def randformer_m36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[6, 6, 18, 6], dims=[96, 192, 384, 768], token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], @@ -1032,7 +1036,7 @@ def randformer_m36(pretrained=False, **kwargs): @register_model def randformer_m48(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[8, 8, 24, 8], dims=[96, 192, 384, 768], token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], @@ -1042,7 +1046,7 @@ def randformer_m48(pretrained=False, **kwargs): @register_model def poolformerv2_s12(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[2, 2, 6, 2], dims=[64, 128, 320, 512], token_mixers=Pooling, @@ -1052,7 +1056,7 @@ def poolformerv2_s12(pretrained=False, **kwargs): @register_model def poolformerv2_s24(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[4, 4, 12, 4], dims=[64, 128, 320, 512], token_mixers=Pooling, @@ -1064,7 +1068,7 @@ def poolformerv2_s24(pretrained=False, **kwargs): @register_model def poolformerv2_s36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[6, 6, 18, 6], dims=[64, 128, 320, 512], token_mixers=Pooling, @@ -1076,7 +1080,7 @@ def poolformerv2_s36(pretrained=False, **kwargs): @register_model def poolformerv2_m36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[6, 6, 18, 6], dims=[96, 192, 384, 768], token_mixers=Pooling, @@ -1087,7 +1091,7 @@ def poolformerv2_m36(pretrained=False, **kwargs): @register_model def poolformerv2_m48(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[8, 8, 24, 8], dims=[96, 192, 384, 768], token_mixers=Pooling, @@ -1099,7 +1103,7 @@ def poolformerv2_m48(pretrained=False, **kwargs): @register_model def convformer_s18(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 3, 9, 3], dims=[64, 128, 320, 512], token_mixers=SepConv, @@ -1112,7 +1116,7 @@ def convformer_s18(pretrained=False, **kwargs): 
@register_model def convformer_s36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 12, 18, 3], dims=[64, 128, 320, 512], token_mixers=SepConv, @@ -1123,7 +1127,7 @@ def convformer_s36(pretrained=False, **kwargs): @register_model def convformer_m36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 12, 18, 3], dims=[96, 192, 384, 576], token_mixers=SepConv, @@ -1135,7 +1139,7 @@ def convformer_m36(pretrained=False, **kwargs): @register_model def convformer_b36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 12, 18, 3], dims=[128, 256, 512, 768], token_mixers=SepConv, @@ -1148,7 +1152,7 @@ def convformer_b36(pretrained=False, **kwargs): @register_model def caformer_s18(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 3, 9, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], @@ -1160,7 +1164,7 @@ def caformer_s18(pretrained=False, **kwargs): @register_model def caformer_s36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 12, 18, 3], dims=[64, 128, 320, 512], token_mixers=[SepConv, SepConv, Attention, Attention], @@ -1171,7 +1175,7 @@ def caformer_s36(pretrained=False, **kwargs): @register_model def caformer_m36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 12, 18, 3], dims=[96, 192, 384, 576], token_mixers=[SepConv, SepConv, Attention, Attention], @@ -1182,7 +1186,7 @@ def caformer_m36(pretrained=False, **kwargs): @register_model def caformer_b36(pretrained=False, **kwargs): - model = MetaFormer( + model_kwargs = dict( depths=[3, 12, 18, 3], dims=[128, 256, 512, 768], token_mixers=[SepConv, SepConv, Attention, Attention], From cd36989a604c86b159331452894c7c33f3fde758 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 17 Jan 2023 20:57:51 -0800 Subject: [PATCH 073/102] Update metaformers.py --- timm/models/metaformers.py | 71 ++++++++++++++------------------------ 1 file changed, 26 insertions(+), 45 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 9aef43a5..0efaedc6 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -60,10 +60,7 @@ class Downsampling(nn.Module): x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) x = self.conv(x) - - x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - print(x[0][0][0][0]) return x ''' class Downsampling(nn.Module): @@ -494,10 +491,11 @@ class MetaFormer(nn.Module): mlp_bias=False, norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False), drop_path_rate=0., - head_dropout=0.0, + drop_rate=0.0, layer_scale_init_values=None, res_scale_init_values=[None, None, 1.0, 1.0], output_norm=partial(nn.LayerNorm, eps=1e-6), + head_norm_first=False, head_fn=nn.Linear, global_pool = 'avg', **kwargs, @@ -506,9 +504,8 @@ class MetaFormer(nn.Module): self.num_classes = num_classes self.head_fn = head_fn self.num_features = dims[-1] - self.head_dropout = head_dropout - self.output_norm = output_norm - + self.drop_rate = drop_rate + if not isinstance(depths, (list, tuple)): depths = [depths] # it means the model has only one stage if not isinstance(dims, (list, tuple)): @@ -586,15 +583,16 @@ class MetaFormer(nn.Module): self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) - self.norm = self.output_norm(self.num_features) - - self.global_pool = 
SelectAdaptivePool2d(pool_type=global_pool) - - if head_dropout > 0.0: - self.head = self.head_fn(self.num_features, self.num_classes, head_dropout=self.head_dropout) - else: - self.head = self.head_fn(self.num_features, self.num_classes) + # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets + # otherwise pool -> norm -> fc, similar to ConvNeXt + self.norm_pre = output_norm(self.num_features) if head_norm_first else nn.Identity() + self.head = nn.Sequential(OrderedDict([ + ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), + ('norm', nn.Identity() if head_norm_first else output_norm(self.num_features)), + ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) self.apply(self._init_weights) @@ -613,40 +611,23 @@ class MetaFormer(nn.Module): return self.head.fc2 def reset_classifier(self, num_classes=0, global_pool=None): - if global_pool is not None: - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - - - if num_classes == 0: - self.head = nn.Identity() - self.norm = nn.Identity() - else: - self.norm = self.output_norm(self.num_features) - if self.head_dropout > 0.0: - self.head = self.head_fn(self.num_features, num_classes, head_dropout=self.head_dropout) - else: - self.head = self.head_fn(self.num_features, num_classes) + self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() + self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_head(self, x, pre_logits: bool = False): - if pre_logits: - return x - - #x = self.global_pool(x) - #x = x.squeeze() - #x = self.norm(x) - # (B, H, W, C) -> (B, C) - #x = self.head(x) - x=self.head(self.norm(x.mean([2, 3]))) - return x + # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( + x = self.head.global_pool(x) + x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x = self.head.flatten(x) + x = self.head.drop(x) + return x if pre_logits else self.head.fc(x) def forward_features(self, x): x = self.patch_embed(x) - #x = self.stages(x) - for i, stage in enumerate(self.stages): - x = stage(x) - - + x = self.stages(x) + x = self.norm_pre(x) return x def forward(self, x): @@ -658,7 +639,6 @@ def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): - ''' k = k.replace('proj', 'conv') k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) k = k.replace('network.1', 'downsample_layers.1') @@ -668,10 +648,11 @@ def checkpoint_filter_fn(state_dict, model): k = k.replace('network.4', 'network.2') k = k.replace('network.6', 'network.3') k = k.replace('network', 'stages') - ''' k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = k.replace('stages.0.downsample', 'patch_embed') + k = re.sub(r'^head', 'head.fc', k) + k = re.sub(r'^norm', 'head.norm', k) out_dict[k] = v return out_dict From 0c89c7eb2de8515c0baab0b8802f963c715140d2 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 00:16:50 -0800 Subject: [PATCH 074/102] rename model --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 0efaedc6..b44d593b 100644 --- a/timm/models/metaformers.py +++ 
b/timm/models/metaformers.py @@ -784,7 +784,7 @@ default_cfgs = generate_default_cfgs({ input_size=(3, 384, 384)), 'convformer_b36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'), - 'convformer_b36_384.sail_in22k_ft_in1k_384': _cfg( + 'convformer_b36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth', input_size=(3, 384, 384)), 'convformer_b36.sail_in22k': _cfg( From 0cf0cc837c996eabafb4c295ae34fa0443f16535 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 08:25:10 -0800 Subject: [PATCH 075/102] Update metaformers.py --- timm/models/metaformers.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index b44d593b..a819421c 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -586,13 +586,22 @@ class MetaFormer(nn.Module): # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # otherwise pool -> norm -> fc, similar to ConvNeXt + # drop removed - if using single fc layer, models have no dropout + # if using MlpHead, dropout is handled by MlpHead + if num_classes > 0: + if self.drop_rate > 0.0: + head = self.head_fn(dims[-1], num_classes, head_dropout=self.drop_rate) + else: + head = self.head_fn(dims[-1], num_classes) + else: + head = nn.Identity() + self.norm_pre = output_norm(self.num_features) if head_norm_first else nn.Identity() self.head = nn.Sequential(OrderedDict([ ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), ('norm', nn.Identity() if head_norm_first else output_norm(self.num_features)), ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), - ('drop', nn.Dropout(self.drop_rate)), - ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) + ('fc', head)])) self.apply(self._init_weights) @@ -608,20 +617,26 @@ class MetaFormer(nn.Module): @torch.jit.ignore def get_classifier(self): - return self.head.fc2 + return self.head.fc def reset_classifier(self, num_classes=0, global_pool=None): if global_pool is not None: self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() - self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + if num_classes > 0: + if self.drop_rate > 0.0: + head = self.head_fn(dims[-1], num_classes, head_dropout=self.drop_rate) + else: + head = self.head_fn(dims[-1], num_classes) + else: + head = nn.Identity() + self.head.fc = head def forward_head(self, x, pre_logits: bool = False): # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( x = self.head.global_pool(x) x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) x = self.head.flatten(x) - x = self.head.drop(x) return x if pre_logits else self.head.fc(x) def forward_features(self, x): From 924a64f051a7636157ed69bf5535c207f5ce42c5 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 09:30:16 -0800 Subject: [PATCH 076/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index a819421c..bcc8dff8 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -693,7 +693,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 
224), 'pool_size': (7, 7), 'crop_pct': 1.0, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'classifier': 'head', 'first_conv': 'patch_embed.conv', + 'classifier': 'head.fc', 'first_conv': 'patch_embed.conv', **kwargs } From 93ed33bc43cf12f9fbdcf1dc85988add7150640e Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 10:41:22 -0800 Subject: [PATCH 077/102] Update metaformers.py --- timm/models/metaformers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index bcc8dff8..bb3d7278 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -751,7 +751,8 @@ default_cfgs = generate_default_cfgs({ 'convformer_s18.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'), + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth', + classifier='head.fc.fc2'), 'convformer_s18.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', input_size=(3, 384, 384)), From 6202d9c898021c70b89ef4008c0b4502e6390e55 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 10:49:32 -0800 Subject: [PATCH 078/102] Update metaformers.py --- timm/models/metaformers.py | 93 ++++++++++++++++++++++---------------- 1 file changed, 54 insertions(+), 39 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index bb3d7278..dda526c2 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -755,114 +755,129 @@ default_cfgs = generate_default_cfgs({ classifier='head.fc.fc2'), 'convformer_s18.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'convformer_s18.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'), + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth', + classifier='head.fc.fc2'), 'convformer_s18.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'convformer_s18.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth', - num_classes=21841), + classifier='head.fc.fc2', num_classes=21841), 'convformer_s36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'), + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth', + classifier='head.fc.fc2'), 'convformer_s36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'convformer_s36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'), + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth', + classifier='head.fc.fc2'), 'convformer_s36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'convformer_s36.sail_in22k': _cfg( 
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth', - num_classes=21841), + classifier='head.fc.fc2', num_classes=21841), 'convformer_m36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'), + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth', + classifier='head.fc.fc2'), 'convformer_m36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'convformer_m36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'), + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth', + classifier='head.fc.fc2'), 'convformer_m36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'convformer_m36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth', - num_classes=21841), + classifier='head.fc.fc2', num_classes=21841), 'convformer_b36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'), + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth', + classifier='head.fc.fc2'), 'convformer_b36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'convformer_b36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'), + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth', + classifier='head.fc.fc2'), 'convformer_b36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'convformer_b36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth', - num_classes=21841), + classifier='head.fc.fc2', num_classes=21841), 'caformer_s18.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'), + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth', + classifier='head.fc.fc2'), 'caformer_s18.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'caformer_s18.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'), + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth', + classifier='head.fc.fc2'), 'caformer_s18.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'caformer_s18.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth', - num_classes=21841), + classifier='head.fc.fc2', num_classes=21841), 'caformer_s36.sail_in1k': _cfg( - 
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'), + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth', + classifier='head.fc.fc2'), 'caformer_s36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'caformer_s36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'), + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth', + classifier='head.fc.fc2'), 'caformer_s36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'caformer_s36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth', - num_classes=21841), + classifier='head.fc.fc2', num_classes=21841), 'caformer_m36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'), + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth', + classifier='head.fc.fc2'), 'caformer_m36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'caformer_m36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'), + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth', + classifier='head.fc.fc2'), 'caformer_m36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'caformer_m36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth', - num_classes=21841), + classifier='head.fc.fc2', num_classes=21841), 'caformer_b36.sail_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'), + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth', + classifier='head.fc.fc2'), 'caformer_b36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'caformer_b36.sail_in22k_ft_in1k': _cfg( - url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'), + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth', + classifier='head.fc.fc2'), 'caformer_b36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth', - input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384)), 'caformer_b36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', - num_classes=21841), + classifier='head.fc.fc2', num_classes=21841), }) From f2c4d6f963588e3e4ba3dd22e00c4b640cc1aa5c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 11:04:30 -0800 Subject: [PATCH 079/102] Update metaformers.py --- timm/models/metaformers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index dda526c2..5d3f7160 100644 --- 
a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -654,7 +654,7 @@ def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): - k = k.replace('proj', 'conv') + #k = k.replace('proj', 'conv') k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) k = k.replace('network.1', 'downsample_layers.1') k = k.replace('network.3', 'downsample_layers.2') @@ -663,6 +663,7 @@ def checkpoint_filter_fn(state_dict, model): k = k.replace('network.4', 'network.2') k = k.replace('network.6', 'network.3') k = k.replace('network', 'stages') + k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = k.replace('stages.0.downsample', 'patch_embed') From debf4e91715a458792e5a0b9cea0c25a1ce2052f Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 11:07:40 -0800 Subject: [PATCH 080/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 5d3f7160..27d29a0a 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -654,7 +654,7 @@ def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): - #k = k.replace('proj', 'conv') + k = k.replace('patch_embed.proj', 'patch_embed.conv') k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) k = k.replace('network.1', 'downsample_layers.1') k = k.replace('network.3', 'downsample_layers.2') From 118017ce3d902ae664e49b979069f1413e8cdf5e Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 11:10:02 -0800 Subject: [PATCH 081/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 27d29a0a..672599a7 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -654,7 +654,7 @@ def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): - k = k.replace('patch_embed.proj', 'patch_embed.conv') + k = k.replace('downsample.proj', 'downsample.conv') k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) k = k.replace('network.1', 'downsample_layers.1') k = k.replace('network.3', 'downsample_layers.2') From 3de6859d9b875ab0a9b71bcae8b3ac84478dc115 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 11:17:15 -0800 Subject: [PATCH 082/102] Update metaformers.py --- timm/models/metaformers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 672599a7..c1aa9520 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -654,7 +654,7 @@ def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): - k = k.replace('downsample.proj', 'downsample.conv') + k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k) k = k.replace('network.1', 'downsample_layers.1') k = k.replace('network.3', 'downsample_layers.2') @@ -665,6 +665,7 @@ def checkpoint_filter_fn(state_dict, model): k = k.replace('network', 'stages') k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) + k = k.replace('downsample.proj', 'downsample.conv') k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = k.replace('stages.0.downsample', 'patch_embed') k = re.sub(r'^head', 'head.fc', k) From 
a51d1073f5497dac5d9165d5504eb245cd77c743 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 11:21:18 -0800 Subject: [PATCH 083/102] Update metaformers.py --- timm/models/metaformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index c1aa9520..7a41cf2a 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -666,6 +666,7 @@ def checkpoint_filter_fn(state_dict, model): k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k) k = k.replace('downsample.proj', 'downsample.conv') + k = k.replace('patch_embed.proj', 'patch_embed.conv') k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = k.replace('stages.0.downsample', 'patch_embed') k = re.sub(r'^head', 'head.fc', k) From 36854046040cd4c43c6b2107367826c640c2f0f5 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 11:52:18 -0800 Subject: [PATCH 084/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 7a41cf2a..90e24b92 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -179,7 +179,7 @@ class RandomMixing(nn.Module): data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), requires_grad=False) def forward(self, x): - B, H, W, C = x.shape + B, C, H, W = x.shape x = x.reshape(B, H*W, C) x = torch.einsum('mn, bnc -> bmc', self.random_matrix, x) x = x.reshape(B, H, W, C) From 3c6821cc127c80f3277c6733585e560cb606a1e1 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 11:55:16 -0800 Subject: [PATCH 085/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 90e24b92..771a0319 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -182,7 +182,7 @@ class RandomMixing(nn.Module): B, C, H, W = x.shape x = x.reshape(B, H*W, C) x = torch.einsum('mn, bnc -> bmc', self.random_matrix, x) - x = x.reshape(B, H, W, C) + x = x.reshape(B, C, H, W) return x From 9efb706b251eaa6f412fd5999b58a07e39750d60 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 12:03:11 -0800 Subject: [PATCH 086/102] Update metaformers.py --- timm/models/metaformers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 771a0319..c684f22a 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -180,6 +180,7 @@ class RandomMixing(nn.Module): requires_grad=False) def forward(self, x): B, C, H, W = x.shape + print(H*W) x = x.reshape(B, H*W, C) x = torch.einsum('mn, bnc -> bmc', self.random_matrix, x) x = x.reshape(B, C, H, W) From eecab9eba82c25caf473dbf0d936663e5bff8ab8 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 12:05:16 -0800 Subject: [PATCH 087/102] Update metaformers.py --- timm/models/metaformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index c684f22a..771a0319 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -180,7 +180,6 @@ class RandomMixing(nn.Module): requires_grad=False) def forward(self, x): B, C, H, W = x.shape - print(H*W) x = x.reshape(B, H*W, C) x = torch.einsum('mn, bnc -> bmc', self.random_matrix, x) x = x.reshape(B, C, H, W) From 383c9fd43da4ccdcd74bc76a3b28ba385b5d1044 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 
18 Jan 2023 12:12:26 -0800 Subject: [PATCH 088/102] Update metaformers.py --- timm/models/metaformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 771a0319..7a41cf2a 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -179,10 +179,10 @@ class RandomMixing(nn.Module): data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), requires_grad=False) def forward(self, x): - B, C, H, W = x.shape + B, H, W, C = x.shape x = x.reshape(B, H*W, C) x = torch.einsum('mn, bnc -> bmc', self.random_matrix, x) - x = x.reshape(B, C, H, W) + x = x.reshape(B, H, W, C) return x From 53f992723cd36498814ffa6d608cb16f9faa6d77 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 12:25:07 -0800 Subject: [PATCH 089/102] Update metaformers.py --- timm/models/metaformers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 7a41cf2a..5b98cba8 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -175,12 +175,16 @@ class Attention(nn.Module): class RandomMixing(nn.Module): def __init__(self, num_tokens=196, **kwargs): super().__init__() + ''' self.random_matrix = nn.parameter.Parameter( data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), requires_grad=False) + ''' + self.random_matrix = torch.softmax(torch.rand(num_tokens, num_tokens) def forward(self, x): B, H, W, C = x.shape x = x.reshape(B, H*W, C) + # FIXME change to work with arbitrary input sizes x = torch.einsum('mn, bnc -> bmc', self.random_matrix, x) x = x.reshape(B, H, W, C) return x From 9a415a60416028cd5541515ae478ade19472005e Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 12:29:42 -0800 Subject: [PATCH 090/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 5b98cba8..d0c2b401 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -180,7 +180,7 @@ class RandomMixing(nn.Module): data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), requires_grad=False) ''' - self.random_matrix = torch.softmax(torch.rand(num_tokens, num_tokens) + self.random_matrix = torch.softmax(torch.rand(num_tokens, num_tokens)) def forward(self, x): B, H, W, C = x.shape x = x.reshape(B, H*W, C) From 61b4e716ca6b58d82f8afe75472c5455e4525519 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 18 Jan 2023 12:30:37 -0800 Subject: [PATCH 091/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index d0c2b401..fd992209 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -180,7 +180,7 @@ class RandomMixing(nn.Module): data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), requires_grad=False) ''' - self.random_matrix = torch.softmax(torch.rand(num_tokens, num_tokens)) + self.random_matrix = torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1) def forward(self, x): B, H, W, C = x.shape x = x.reshape(B, H*W, C) From 143f8e69b1302011d33283528a4b98d955633448 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Fri, 20 Jan 2023 01:53:36 -0800 Subject: [PATCH 092/102] Update metaformers.py --- timm/models/metaformers.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/timm/models/metaformers.py 
b/timm/models/metaformers.py index fd992209..3b21a209 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -175,12 +175,9 @@ class Attention(nn.Module): class RandomMixing(nn.Module): def __init__(self, num_tokens=196, **kwargs): super().__init__() - ''' self.random_matrix = nn.parameter.Parameter( data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), requires_grad=False) - ''' - self.random_matrix = torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1) def forward(self, x): B, H, W, C = x.shape x = x.reshape(B, H*W, C) @@ -444,13 +441,20 @@ class MetaFormerBlock(nn.Module): x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( - self.token_mixer(self.norm1(x)) + self.token_mixer( + self.norm1( + x.permute(0, 3, 1, 2) + ).permute(0, 2, 3, 1) + ) ) ) x = self.res_scale2(x) + \ self.layer_scale2( self.drop_path2( - self.mlp(self.norm2(x)) + self.mlp(self.norm2( + x.permute(0, 3, 1, 2) + ).permute(0, 2, 3, 1) + ) ) ) #x = x.view(B, C, H, W) From dd57cde1bc3905a4c724184d6bd5509bba4834c4 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Fri, 20 Jan 2023 01:56:03 -0800 Subject: [PATCH 093/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 3b21a209..de5433f6 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -453,7 +453,7 @@ class MetaFormerBlock(nn.Module): self.drop_path2( self.mlp(self.norm2( x.permute(0, 3, 1, 2) - ).permute(0, 2, 3, 1) + )#.permute(0, 2, 3, 1) ) ) ) From 729533e96626cdc5b72aae1f6fb2b12341e35d31 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Fri, 20 Jan 2023 02:01:06 -0800 Subject: [PATCH 094/102] Update metaformers.py --- timm/models/metaformers.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index de5433f6..fbd93ea0 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -441,20 +441,13 @@ class MetaFormerBlock(nn.Module): x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( - self.token_mixer( - self.norm1( - x.permute(0, 3, 1, 2) - ).permute(0, 2, 3, 1) - ) + self.token_mixer(self.norm1(x)) ) ) x = self.res_scale2(x) + \ self.layer_scale2( self.drop_path2( - self.mlp(self.norm2( - x.permute(0, 3, 1, 2) - )#.permute(0, 2, 3, 1) - ) + self.mlp(self.norm2(x)) ) ) #x = x.view(B, C, H, W) @@ -915,10 +908,10 @@ def poolformerv1_s24(pretrained=False, **kwargs): dims=[64, 128, 320, 512], downsample_norm=None, token_mixers=Pooling, - mlp_fn=partial(nn.Conv2d, kernel_size=1), + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), mlp_act=nn.GELU, mlp_bias=True, - norm_layers=GroupNorm1, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), layer_scale_init_values=1e-5, res_scale_init_values=None, **kwargs) @@ -931,10 +924,10 @@ def poolformerv1_s36(pretrained=False, **kwargs): dims=[64, 128, 320, 512], downsample_norm=None, token_mixers=Pooling, - mlp_fn=partial(nn.Conv2d, kernel_size=1), + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), mlp_act=nn.GELU, mlp_bias=True, - norm_layers=GroupNorm1, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), layer_scale_init_values=1e-6, res_scale_init_values=None, **kwargs) @@ -947,10 +940,10 @@ def poolformerv1_m36(pretrained=False, **kwargs): dims=[96, 192, 384, 768], downsample_norm=None, token_mixers=Pooling, - 
mlp_fn=partial(nn.Conv2d, kernel_size=1), + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), mlp_act=nn.GELU, mlp_bias=True, - norm_layers=GroupNorm1, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), layer_scale_init_values=1e-6, res_scale_init_values=None, **kwargs) @@ -963,10 +956,10 @@ def poolformerv1_m48(pretrained=False, **kwargs): dims=[96, 192, 384, 768], downsample_norm=None, token_mixers=Pooling, - mlp_fn=partial(nn.Conv2d, kernel_size=1), + mlp_fn=partial(Conv2dChannelsLast, kernel_size=1), mlp_act=nn.GELU, mlp_bias=True, - norm_layers=GroupNorm1, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=True), layer_scale_init_values=1e-6, res_scale_init_values=None, **kwargs) From 7c04f6dc759dd302ee355b1955f8a6c0f6194168 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 28 Jan 2023 15:09:12 -0800 Subject: [PATCH 095/102] Update metaformers.py --- timm/models/metaformers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index fbd93ea0..5935b8e4 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -543,7 +543,7 @@ class MetaFormer(nn.Module): dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - self.patch_embed = Downsampling( + self.stem = Downsampling( in_chans, dims[0], kernel_size=7, @@ -670,6 +670,7 @@ def checkpoint_filter_fn(state_dict, model): k = k.replace('patch_embed.proj', 'patch_embed.conv') k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = k.replace('stages.0.downsample', 'patch_embed') + k = k.replace('patch_embed', 'stem') k = re.sub(r'^head', 'head.fc', k) k = re.sub(r'^norm', 'head.norm', k) out_dict[k] = v From 01d02fb40f4c89e47f4cb4d825051b8bf24691d0 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 31 Jan 2023 16:43:19 -0800 Subject: [PATCH 096/102] Update metaformers.py --- timm/models/metaformers.py | 74 ++------------------------------------ 1 file changed, 3 insertions(+), 71 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 5935b8e4..07d150c9 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -21,6 +21,8 @@ original copyright below # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + + from collections import OrderedDict from functools import partial import torch @@ -62,30 +64,8 @@ class Downsampling(nn.Module): x = self.conv(x) x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x -''' -class Downsampling(nn.Module): - """ - Downsampling implemented by a layer of convolution. - """ - def __init__(self, in_channels, out_channels, - kernel_size, stride=1, padding=0, - pre_norm=None, post_norm=None, pre_permute = False): - super().__init__() - self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, - stride=stride, padding=padding) - self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() - def forward(self, x): - print(x.shape) - x = self.pre_norm(x) - print(x.shape) - x = self.conv(x) - print(x.shape) - x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - print(x.shape) - return x -''' + class Scale(nn.Module): """ Scale vector by element multiplications. 
@@ -237,55 +217,7 @@ class LayerNormGeneral(nn.Module): x = x * self.weight x = x + self.bias return x -''' -class LayerNormGeneral(nn.Module): - r""" General LayerNorm for different situations. - Args: - affine_shape (int, list or tuple): The shape of affine weight and bias. - Usually the affine_shape=C, but in some implementation, like torch.nn.LayerNorm, - the affine_shape is the same as normalized_dim by default. - To adapt to different situations, we offer this argument here. - normalized_dim (tuple or list): Which dims to compute mean and variance. - scale (bool): Flag indicates whether to use scale or not. - bias (bool): Flag indicates whether to use scale or not. - We give several examples to show how to specify the arguments. - LayerNorm (https://arxiv.org/abs/1607.06450): - For input shape of (B, *, C) like (B, N, C) or (B, H, W, C), - affine_shape=C, normalized_dim=(-1, ), scale=True, bias=True; - For input shape of (B, C, H, W), - affine_shape=(C, 1, 1), normalized_dim=(1, ), scale=True, bias=True. - Modified LayerNorm (https://arxiv.org/abs/2111.11418) - that is idental to partial(torch.nn.GroupNorm, num_groups=1): - For input shape of (B, N, C), - affine_shape=C, normalized_dim=(1, 2), scale=True, bias=True; - For input shape of (B, H, W, C), - affine_shape=C, normalized_dim=(1, 2, 3), scale=True, bias=True; - For input shape of (B, C, H, W), - affine_shape=(C, 1, 1), normalized_dim=(1, 2, 3), scale=True, bias=True. - For the several metaformer baslines, - IdentityFormer, RandFormer and PoolFormerV2 utilize Modified LayerNorm without bias (bias=False); - ConvFormer and CAFormer utilizes LayerNorm without bias (bias=False). - """ - def __init__(self, affine_shape=None, normalized_dim=(-1, ), scale=True, - bias=True, eps=1e-5): - super().__init__() - self.normalized_dim = normalized_dim - self.use_scale = scale - self.use_bias = bias - self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None - self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None - self.eps = eps - def forward(self, x): - c = x - x.mean(self.normalized_dim, keepdim=True) - s = c.pow(2).mean(self.normalized_dim, keepdim=True) - x = c / torch.sqrt(s + self.eps) - if self.use_scale: - x = x * self.weight - if self.use_bias: - x = x + self.bias - return x -''' class SepConv(nn.Module): r""" Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381. 
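Note: the SepConv token mixer whose docstring closes the hunk above is the MobileNetV2-style inverted separable convolution: pointwise expansion, depthwise convolution, then pointwise projection. A rough channels-last sketch of that structure (the class name, activation, and expansion ratio here are assumptions for illustration, not the exact code in this file):

    import torch
    import torch.nn as nn

    class SepConvSketch(nn.Module):
        """Pointwise expand -> depthwise conv -> pointwise project, on NHWC input."""
        def __init__(self, dim, expansion_ratio=2, kernel_size=7):
            super().__init__()
            mid = int(expansion_ratio * dim)
            self.pw1 = nn.Linear(dim, mid)          # pointwise, acts on the trailing channel dim
            self.act = nn.ReLU()
            self.dw = nn.Conv2d(mid, mid, kernel_size, padding=kernel_size // 2, groups=mid)
            self.pw2 = nn.Linear(mid, dim)

        def forward(self, x):                       # x: [B, H, W, C]
            x = self.act(self.pw1(x))
            x = self.dw(x.permute(0, 3, 1, 2))      # NHWC -> NCHW for the depthwise conv
            x = x.permute(0, 2, 3, 1)               # back to NHWC
            return self.pw2(x)

    print(SepConvSketch(64)(torch.randn(1, 14, 14, 64)).shape)  # torch.Size([1, 14, 14, 64])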
From 8f70650295cf808b529d2a86918562acbf825974 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 31 Jan 2023 16:45:38 -0800 Subject: [PATCH 097/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 07d150c9..7a90b448 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -630,7 +630,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 1.0, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'classifier': 'head.fc', 'first_conv': 'patch_embed.conv', + 'classifier': 'head.fc', 'first_conv': 'stem.conv', **kwargs } From 1fd5f7672d2dec34f018937a5e7bc22d540fa29a Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 31 Jan 2023 17:03:25 -0800 Subject: [PATCH 098/102] Update metaformers.py --- timm/models/metaformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 7a90b448..1cb13141 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -573,7 +573,7 @@ class MetaFormer(nn.Module): return x if pre_logits else self.head.fc(x) def forward_features(self, x): - x = self.patch_embed(x) + x = self.stem(x) x = self.stages(x) x = self.norm_pre(x) return x From 5d9cb3b943039ee597af682b9782309722a0dfd1 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 31 Jan 2023 19:25:47 -0800 Subject: [PATCH 099/102] Update metaformers.py --- timm/models/metaformers.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 1cb13141..30cba4cc 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -546,7 +546,7 @@ class MetaFormer(nn.Module): @torch.jit.ignore def set_grad_checkpointing(self, enable=True): - print("not implemented") + self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self): @@ -574,7 +574,10 @@ class MetaFormer(nn.Module): def forward_features(self, x): x = self.stem(x) - x = self.stages(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) x = self.norm_pre(x) return x @@ -583,6 +586,8 @@ class MetaFormer(nn.Module): x = self.forward_head(x) return x +# FIXME convert to group matcher +# this works but it's long and breaks backwards compatability with weights from the poolformer-only impl def checkpoint_filter_fn(state_dict, model): import re out_dict = {} @@ -817,6 +822,7 @@ default_cfgs = generate_default_cfgs({ classifier='head.fc.fc2', num_classes=21841), }) +# FIXME fully merge poolformerv1, rename to poolformer to succeed poolformer.py @register_model def poolformerv1_s12(pretrained=False, **kwargs): From 5a19034a99d4f4473f94a8124455b03e9c7af06f Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 31 Jan 2023 22:15:08 -0800 Subject: [PATCH 100/102] Update metaformers.py --- timm/models/metaformers.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 30cba4cc..f5e12c5a 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -1,11 +1,10 @@ - - """ +Poolformer from MetaFormer is Actually What You Need for Vision https://arxiv.org/abs/2111.11418 MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2, -ConvFormer 
and CAFormer. +ConvFormer, and CAFormer as per https://arxiv.org/abs/2210.13452 -original copyright below +Adapted from https://github.com/sail-sg/metaformer, original copyright below """ # Copyright 2022 Garena Online Private Limited @@ -155,6 +154,7 @@ class Attention(nn.Module): class RandomMixing(nn.Module): def __init__(self, num_tokens=196, **kwargs): super().__init__() + # FIXME no grad breaks tests self.random_matrix = nn.parameter.Parameter( data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), requires_grad=False) @@ -367,8 +367,6 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() def forward(self, x): - #B, C, H, W = x.shape - #x = x.view(B, H, W, C) x = x.permute(0, 2, 3, 1) x = self.res_scale1(x) + \ self.layer_scale1( @@ -382,7 +380,6 @@ class MetaFormerBlock(nn.Module): self.mlp(self.norm2(x)) ) ) - #x = x.view(B, C, H, W) x = x.permute(0, 3, 1, 2) return x @@ -396,7 +393,6 @@ class MetaFormer(nn.Module): num_classes (int): Number of classes for classification head. Default: 1000. depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2]. dims (int): Feature dimension at each stage. Default: [64, 128, 320, 512]. - downsample_layers: (list or tuple): Downsampling layers before each stage. token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity. mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp. norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormGeneral, eps=1e-6, bias=False). @@ -415,7 +411,6 @@ class MetaFormer(nn.Module): num_classes=1000, depths=[2, 2, 6, 2], dims=[64, 128, 320, 512], - #downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES, downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), token_mixers=nn.Identity, mlps=Mlp, @@ -445,15 +440,7 @@ class MetaFormer(nn.Module): dims = [dims] self.num_stages = len(depths) - ''' - if not isinstance(downsample_layers, (list, tuple)): - downsample_layers = [downsample_layers] * self.num_stages - down_dims = [in_chans] + dims - - downsample_layers = nn.ModuleList( - [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(self.num_stages)] - ) - ''' + if not isinstance(token_mixers, (list, tuple)): token_mixers = [token_mixers] * self.num_stages From e2a9408dd07142bf888bfee1d036dc9734b60fd0 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Fri, 10 Feb 2023 02:09:35 -0800 Subject: [PATCH 101/102] Squashed commit of the following: commit b7696a30a772dbbb2e00d81e7096c24dac97df73 Author: Fredo Guan Date: Fri Feb 10 01:46:44 2023 -0800 Update metaformers.py commit 41fe5c36263b40a6cd7caddb85b10c5d82d48023 Author: Fredo Guan Date: Fri Feb 10 01:03:47 2023 -0800 Update metaformers.py commit a3aee37c35985c01ca07902d860f809b648c612c Author: Fredo Guan Date: Fri Feb 10 00:32:04 2023 -0800 Update metaformers.py commit f938beb81b4f46851d6d6f04ae7a9a74871ee40d Author: Fredo Guan Date: Fri Feb 10 00:24:58 2023 -0800 Update metaformers.py commit 10bde717e51c95cdf20135c8bba77a7a1b00d78c Author: Fredo Guan Date: Sun Feb 5 02:11:28 2023 -0800 Update metaformers.py commit 39274bd45e78b8ead0509367f800121f0d7c25f4 Author: Fredo Guan Date: Sun Feb 5 02:06:58 2023 -0800 Update metaformers.py commit a2329ab8ec00d0ebc00979690293c3887cc44a4c Author: Fredo Guan Date: Sun Feb 5 02:03:34 2023 -0800 Update metaformers.py commit 53b8ce5b8a6b6d828de61788bcc2e6043ebb3081 Author: Fredo Guan Date: Sun Feb 5 02:02:37 2023 -0800 Update metaformers.py commit ab6225b9414f534815958036f6d5a392038d7ab2 Author: 
Fredo Guan Date: Sun Feb 5 01:04:55 2023 -0800 try NHWC commit 02fcc30eaa67a3c92cae56f3062b2542c32c9283 Author: Fredo Guan Date: Sat Feb 4 23:47:06 2023 -0800 Update metaformers.py commit 366aae93047934bd3d7d37a077e713d424fa429c Author: Fredo Guan Date: Sat Feb 4 23:37:30 2023 -0800 Stem/Downsample rework commit 26a8e481a5cb2a32004a796bc25c6800cd2fb7b7 Author: Fredo Guan Date: Wed Feb 1 07:42:07 2023 -0800 Update metaformers.py commit a913f5d4384aa4b2f62fdab46254ae3772df00ee Author: Fredo Guan Date: Wed Feb 1 07:41:24 2023 -0800 Update metaformers.py --- timm/models/metaformers.py | 204 ++++++++++++++++++++++++++----------- 1 file changed, 143 insertions(+), 61 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index f5e12c5a..36263e34 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -24,8 +24,11 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below from collections import OrderedDict from functools import partial + import torch import torch.nn as nn +from torch import Tensor + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1 from timm.layers.helpers import to_2tuple @@ -40,28 +43,58 @@ from ._registry import register_model __all__ = ['MetaFormer'] + +class Stem(nn.Module): + """ + Stem implemented by a layer of convolution. + Conv2d params constant across all models. + """ + def __init__(self, + in_channels, + out_channels, + norm_layer=None, + ): + super().__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=4, + padding=2 + ) + self.norm = norm_layer(out_channels) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.conv(x) + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + # [B, C, H, W] + return x + class Downsampling(nn.Module): """ Downsampling implemented by a layer of convolution. """ - def __init__(self, in_channels, out_channels, - kernel_size, stride=1, padding=0, - pre_norm=None, post_norm=None, pre_permute=False): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + norm_layer=None, + ): super().__init__() - self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity() - self.pre_permute = pre_permute - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, - stride=stride, padding=padding) - self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() + self.norm = norm_layer(in_channels) if norm_layer else nn.Identity() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) def forward(self, x): - if self.pre_permute: - # if take [B, H, W, C] as input, permute it to [B, C, H, W] - x = x.permute(0, 3, 1, 2) - x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) x = self.conv(x) - x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x @@ -299,7 +332,6 @@ class Mlp(nn.Module): return x - class MlpHead(nn.Module): """ MLP classification head """ @@ -323,7 +355,6 @@ class MlpHead(nn.Module): return x - class MetaFormerBlock(nn.Module): """ Implementation of one MetaFormer block. 
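Note on the Stem/Downsample rework above: both modules keep their convolutions in NCHW but wrap the norm layer in a permute pair so it normalizes over the trailing channel dim. A self-contained sketch of that pattern, using plain nn.LayerNorm as a stand-in for the norm layers actually used in this file:

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=2)  # same conv params as the Stem above
    norm = nn.LayerNorm(64)                                      # stand-in for LayerNormGeneral

    x = torch.randn(2, 3, 224, 224)
    y = conv(x)                                                  # [2, 64, 56, 56], NCHW
    y = norm(y.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)          # norm in NHWC, then back to NCHW
    print(y.shape)                                               # torch.Size([2, 64, 56, 56])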
@@ -367,7 +398,6 @@ class MetaFormerBlock(nn.Module): if res_scale_init_value else nn.Identity() def forward(self, x): - x = x.permute(0, 2, 3, 1) x = self.res_scale1(x) + \ self.layer_scale1( self.drop_path1( @@ -380,6 +410,69 @@ class MetaFormerBlock(nn.Module): self.mlp(self.norm2(x)) ) ) + return x + +class MetaFormerStage(nn.Module): + # implementation of a single metaformer stage + def __init__( + self, + in_chs, + out_chs, + depth=2, + downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), + token_mixer=nn.Identity, + mlp=Mlp, + mlp_fn=nn.Linear, + mlp_act=StarReLU, + mlp_bias=False, + norm_layer=partial(LayerNormGeneral, eps=1e-6, bias=False), + dp_rates=[0.]*2, + layer_scale_init_value=None, + res_scale_init_value=None, + ): + super().__init__() + + self.grad_checkpointing = False + + # don't downsample if in_chs and out_chs are the same + self.downsample = nn.Identity() if in_chs == out_chs else Downsampling( + in_chs, + out_chs, + kernel_size=3, + stride=2, + padding=1, + norm_layer=downsample_norm + ) + + self.blocks = nn.Sequential(*[MetaFormerBlock( + dim=out_chs, + token_mixer=token_mixer, + mlp=mlp, + mlp_fn=mlp_fn, + mlp_act=mlp_act, + mlp_bias=mlp_bias, + norm_layer=norm_layer, + drop_path=dp_rates[i], + layer_scale_init_value=layer_scale_init_value, + res_scale_init_value=res_scale_init_value + ) for i in range(depth)]) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + # Permute to channels-first for feature extraction + def forward(self, x: Tensor): + + # [B, C, H, W] -> [B, H, W, C] + x = self.downsample(x).permute(0, 2, 3, 1) + + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + + # [B, H, W, C] -> [B, C, H, W] x = x.permute(0, 3, 1, 2) return x @@ -415,7 +508,7 @@ class MetaFormer(nn.Module): token_mixers=nn.Identity, mlps=Mlp, mlp_fn=nn.Linear, - mlp_act = StarReLU, + mlp_act=StarReLU, mlp_bias=False, norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False), drop_path_rate=0., @@ -433,24 +526,19 @@ class MetaFormer(nn.Module): self.head_fn = head_fn self.num_features = dims[-1] self.drop_rate = drop_rate + self.num_stages = len(depths) + # convert everything to lists if they aren't indexable if not isinstance(depths, (list, tuple)): depths = [depths] # it means the model has only one stage if not isinstance(dims, (list, tuple)): dims = [dims] - - self.num_stages = len(depths) - if not isinstance(token_mixers, (list, tuple)): token_mixers = [token_mixers] * self.num_stages - if not isinstance(mlps, (list, tuple)): mlps = [mlps] * self.num_stages - if not isinstance(norm_layers, (list, tuple)): norm_layers = [norm_layers] * self.num_stages - - if not isinstance(layer_scale_init_values, (list, tuple)): layer_scale_init_values = [layer_scale_init_values] * self.num_stages if not isinstance(res_scale_init_values, (list, tuple)): @@ -459,47 +547,37 @@ class MetaFormer(nn.Module): self.grad_checkpointing = False self.feature_info = [] - dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] - + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] - self.stem = Downsampling( + self.stem = Stem( in_chans, dims[0], - kernel_size=7, - stride=4, - padding=2, - post_norm=downsample_norm + norm_layer=downsample_norm ) stages = nn.ModuleList() # each stage consists of multiple metaformer blocks cur = 0 + last_dim = dims[0] for i in range(self.num_stages): - stage = 
nn.Sequential(OrderedDict([ - ('downsample', nn.Identity() if i == 0 else Downsampling( - dims[i-1], - dims[i], - kernel_size=3, - stride=2, - padding=1, - pre_norm=downsample_norm, - pre_permute=False - )), - ('blocks', nn.Sequential(*[MetaFormerBlock( - dim=dims[i], - token_mixer=token_mixers[i], - mlp=mlps[i], - mlp_fn=mlp_fn, - mlp_act=mlp_act, - mlp_bias=mlp_bias, - norm_layer=norm_layers[i], - drop_path=dp_rates[cur + j], - layer_scale_init_value=layer_scale_init_values[i], - res_scale_init_value=res_scale_init_values[i] - ) for j in range(depths[i])]) - )]) + stage = MetaFormerStage( + last_dim, + dims[i], + depth=depths[i], + downsample_norm=downsample_norm, + token_mixer=token_mixers[i], + mlp=mlps[i], + mlp_fn=mlp_fn, + mlp_act=mlp_act, + mlp_bias=mlp_bias, + norm_layer=norm_layers[i], + dp_rates=dp_rates[i], + layer_scale_init_value=layer_scale_init_values[i], + res_scale_init_value=res_scale_init_values[i], ) + stages.append(stage) cur += depths[i] + last_dim = dims[i] self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) @@ -515,7 +593,7 @@ class MetaFormer(nn.Module): head = self.head_fn(dims[-1], num_classes) else: head = nn.Identity() - + self.norm_pre = output_norm(self.num_features) if head_norm_first else nn.Identity() self.head = nn.Sequential(OrderedDict([ ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), @@ -534,6 +612,8 @@ class MetaFormer(nn.Module): @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable + for stage in self.stages: + stage.set_grad_checkpointing(enable=enable) @torch.jit.ignore def get_classifier(self): @@ -552,23 +632,23 @@ class MetaFormer(nn.Module): head = nn.Identity() self.head.fc = head - def forward_head(self, x, pre_logits: bool = False): + def forward_head(self, x: Tensor, pre_logits: bool = False): # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( x = self.head.global_pool(x) x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) x = self.head.flatten(x) return x if pre_logits else self.head.fc(x) - def forward_features(self, x): + def forward_features(self, x: Tensor): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.stages, x) else: x = self.stages(x) - x = self.norm_pre(x) + x = self.norm_pre(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x - def forward(self, x): + def forward(self, x: Tensor): x = self.forward_features(x) x = self.forward_head(x) return x @@ -595,6 +675,8 @@ def checkpoint_filter_fn(state_dict, model): k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k) k = k.replace('stages.0.downsample', 'patch_embed') k = k.replace('patch_embed', 'stem') + k = k.replace('post_norm', 'norm') + k = k.replace('pre_norm', 'norm') k = re.sub(r'^head', 'head.fc', k) k = re.sub(r'^norm', 'head.norm', k) out_dict[k] = v @@ -684,7 +766,7 @@ default_cfgs = generate_default_cfgs({ classifier='head.fc.fc2'), 'convformer_s18.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'convformer_s18.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth', classifier='head.fc.fc2'), From a73d414c7e58ad353c64315a14767fd1db8663c7 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Fri, 10 Feb 2023 
03:33:41 -0800 Subject: [PATCH 102/102] Update metaformers.py --- timm/models/metaformers.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 36263e34..8cc373eb 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -772,7 +772,7 @@ default_cfgs = generate_default_cfgs({ classifier='head.fc.fc2'), 'convformer_s18.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'convformer_s18.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth', classifier='head.fc.fc2', num_classes=21841), @@ -782,13 +782,13 @@ default_cfgs = generate_default_cfgs({ classifier='head.fc.fc2'), 'convformer_s36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'convformer_s36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth', classifier='head.fc.fc2'), 'convformer_s36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'convformer_s36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth', classifier='head.fc.fc2', num_classes=21841), @@ -798,13 +798,13 @@ default_cfgs = generate_default_cfgs({ classifier='head.fc.fc2'), 'convformer_m36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'convformer_m36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth', classifier='head.fc.fc2'), 'convformer_m36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'convformer_m36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth', classifier='head.fc.fc2', num_classes=21841), @@ -814,13 +814,13 @@ default_cfgs = generate_default_cfgs({ classifier='head.fc.fc2'), 'convformer_b36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'convformer_b36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth', classifier='head.fc.fc2'), 'convformer_b36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'convformer_b36.sail_in22k': _cfg( 
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth', classifier='head.fc.fc2', num_classes=21841), @@ -831,13 +831,13 @@ default_cfgs = generate_default_cfgs({ classifier='head.fc.fc2'), 'caformer_s18.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'caformer_s18.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth', classifier='head.fc.fc2'), 'caformer_s18.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'caformer_s18.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth', classifier='head.fc.fc2', num_classes=21841), @@ -847,13 +847,13 @@ default_cfgs = generate_default_cfgs({ classifier='head.fc.fc2'), 'caformer_s36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'caformer_s36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth', classifier='head.fc.fc2'), 'caformer_s36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'caformer_s36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth', classifier='head.fc.fc2', num_classes=21841), @@ -863,13 +863,13 @@ default_cfgs = generate_default_cfgs({ classifier='head.fc.fc2'), 'caformer_m36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'caformer_m36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth', classifier='head.fc.fc2'), 'caformer_m36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'caformer_m36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth', classifier='head.fc.fc2', num_classes=21841), @@ -879,13 +879,13 @@ default_cfgs = generate_default_cfgs({ classifier='head.fc.fc2'), 'caformer_b36.sail_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'caformer_b36.sail_in22k_ft_in1k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth', classifier='head.fc.fc2'), 'caformer_b36.sail_in22k_ft_in1k_384': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth', - classifier='head.fc.fc2', input_size=(3, 384, 384)), + 
classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)), 'caformer_b36.sail_in22k': _cfg( url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', classifier='head.fc.fc2', num_classes=21841),
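Note on [PATCH 102/102]: the pool_size=(12, 12) added to the 384-input configs follows from the network's overall stride. The stem downsamples by 4 and each of the three later stages by 2, an overall stride of 32, so a 384x384 input yields a 12x12 final feature map, just as a 224x224 input yields the default 7x7. A quick check of that arithmetic (illustrative only):

    def final_feat_size(img_size, strides=(4, 2, 2, 2)):
        """Side length of the final feature map after the stem and three stage downsamples."""
        size = img_size
        for s in strides:
            size //= s
        return size

    print(final_feat_size(224))  # 7  -> default pool_size (7, 7)
    print(final_feat_size(384))  # 12 -> pool_size=(12, 12) set in this patch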