@@ -24,8 +24,11 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
 from collections import OrderedDict
 from functools import partial

 import torch
 import torch.nn as nn
+from torch import Tensor

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1
 from timm.layers.helpers import to_2tuple
@@ -40,28 +43,58 @@ from ._registry import register_model

 __all__ = ['MetaFormer']


+class Stem(nn.Module):
+    """
+    Stem implemented by a layer of convolution.
+    Conv2d params constant across all models.
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 norm_layer=None,
+    ):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=7,
+            stride=4,
+            padding=2
+        )
+        self.norm = norm_layer(out_channels) if norm_layer else nn.Identity()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
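+        # NOTE: norm_layer is assumed to expect channels-last input, hence the
+        # NCHW -> NHWC -> NCHW round-trip around it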
+        # [B, C, H, W]
+        return x
+
+
 class Downsampling(nn.Module):
     """
     Downsampling implemented by a layer of convolution.
     """
-    def __init__(self, in_channels, out_channels,
-                 kernel_size, stride=1, padding=0,
-                 pre_norm=None, post_norm=None, pre_permute=False):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 norm_layer=None,
+    ):
         super().__init__()
-        self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
-        self.pre_permute = pre_permute
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
-                              stride=stride, padding=padding)
-        self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()
+        self.norm = norm_layer(in_channels) if norm_layer else nn.Identity()
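+        # a single input-side norm replaces the former pre_norm/post_norm pair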
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding
+        )

     def forward(self, x):
-        if self.pre_permute:
-            # if take [B, H, W, C] as input, permute it to [B, C, H, W]
-            x = x.permute(0, 3, 1, 2)
-        x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
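+        # channels-first input is now assumed; the permute only serves the channels-last norm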
         x = self.conv(x)
-        x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         return x

@@ -299,7 +332,6 @@ class Mlp(nn.Module):
         return x

-
 class MlpHead(nn.Module):
     """ MLP classification head
     """
@@ -323,7 +355,6 @@ class MlpHead(nn.Module):
         return x

-
 class MetaFormerBlock(nn.Module):
     """
     Implementation of one MetaFormer block.
@@ -367,7 +398,6 @@ class MetaFormerBlock(nn.Module):
             if res_scale_init_value else nn.Identity()
-
     def forward(self, x):
         x = x.permute(0, 2, 3, 1)
         x = self.res_scale1(x) + \
             self.layer_scale1(
                 self.drop_path1(
@@ -380,6 +410,69 @@ class MetaFormerBlock(nn.Module):
                     self.mlp(self.norm2(x))
                 )
             )
         return x


+class MetaFormerStage(nn.Module):
+    # implementation of a single metaformer stage
+    def __init__(
+            self,
+            in_chs,
+            out_chs,
+            depth=2,
+            downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6),
+            token_mixer=nn.Identity,
+            mlp=Mlp,
+            mlp_fn=nn.Linear,
+            mlp_act=StarReLU,
+            mlp_bias=False,
+            norm_layer=partial(LayerNormGeneral, eps=1e-6, bias=False),
+            dp_rates=[0.]*2,
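+            # one stochastic depth rate per block in this stage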
+            layer_scale_init_value=None,
+            res_scale_init_value=None,
+    ):
+        super().__init__()
+
+        self.grad_checkpointing = False
+        # don't downsample if in_chs and out_chs are the same
+        self.downsample = nn.Identity() if in_chs == out_chs else Downsampling(
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            norm_layer=downsample_norm
+        )
+
+        self.blocks = nn.Sequential(*[MetaFormerBlock(
+            dim=out_chs,
+            token_mixer=token_mixer,
+            mlp=mlp,
+            mlp_fn=mlp_fn,
+            mlp_act=mlp_act,
+            mlp_bias=mlp_bias,
+            norm_layer=norm_layer,
+            drop_path=dp_rates[i],
+            layer_scale_init_value=layer_scale_init_value,
+            res_scale_init_value=res_scale_init_value
+        ) for i in range(depth)])
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
+    def forward(self, x: Tensor):
+        # [B, C, H, W] -> [B, H, W, C]
+        x = self.downsample(x).permute(0, 2, 3, 1)
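+        # the block stack (token mixers, norms, MLPs) runs channels-last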
+
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
+
+        # [B, H, W, C] -> [B, C, H, W], permute back to channels-first for feature extraction
+        x = x.permute(0, 3, 1, 2)
+        return x

@@ -415,7 +508,7 @@ class MetaFormer(nn.Module):
             token_mixers=nn.Identity,
             mlps=Mlp,
             mlp_fn=nn.Linear,
-            mlp_act = StarReLU,
+            mlp_act=StarReLU,
             mlp_bias=False,
             norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False),
             drop_path_rate=0.,
@@ -433,24 +526,19 @@ class MetaFormer(nn.Module):
         self.head_fn = head_fn
         self.num_features = dims[-1]
         self.drop_rate = drop_rate
-        self.num_stages = len(depths)

         # convert everything to lists if they aren't indexable
         if not isinstance(depths, (list, tuple)):
             depths = [depths]  # it means the model has only one stage
         if not isinstance(dims, (list, tuple)):
             dims = [dims]
+        self.num_stages = len(depths)

         if not isinstance(token_mixers, (list, tuple)):
             token_mixers = [token_mixers] * self.num_stages
         if not isinstance(mlps, (list, tuple)):
             mlps = [mlps] * self.num_stages
         if not isinstance(norm_layers, (list, tuple)):
             norm_layers = [norm_layers] * self.num_stages
         if not isinstance(layer_scale_init_values, (list, tuple)):
             layer_scale_init_values = [layer_scale_init_values] * self.num_stages
         if not isinstance(res_scale_init_values, (list, tuple)):
@@ -459,47 +547,37 @@ class MetaFormer(nn.Module):
         self.grad_checkpointing = False
         self.feature_info = []

-        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
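+        # linspace yields one rate per block over the whole network; split(depths)
+        # regroups them into a per-stage list of per-block rates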

-        self.stem = Downsampling(
+        self.stem = Stem(
             in_chans,
             dims[0],
-            kernel_size=7,
-            stride=4,
-            padding=2,
-            post_norm=downsample_norm
+            norm_layer=downsample_norm
         )

         stages = nn.ModuleList()  # each stage consists of multiple metaformer blocks
-        cur = 0
+        last_dim = dims[0]
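+        # width of the previous stage; MetaFormerStage only downsamples when in_chs != out_chs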
         for i in range(self.num_stages):
-            stage = nn.Sequential(OrderedDict([
-                ('downsample', nn.Identity() if i == 0 else Downsampling(
-                    dims[i-1],
-                    dims[i],
-                    kernel_size=3,
-                    stride=2,
-                    padding=1,
-                    pre_norm=downsample_norm,
-                    pre_permute=False
-                )),
-                ('blocks', nn.Sequential(*[MetaFormerBlock(
-                    dim=dims[i],
-                    token_mixer=token_mixers[i],
-                    mlp=mlps[i],
-                    mlp_fn=mlp_fn,
-                    mlp_act=mlp_act,
-                    mlp_bias=mlp_bias,
-                    norm_layer=norm_layers[i],
-                    drop_path=dp_rates[cur + j],
-                    layer_scale_init_value=layer_scale_init_values[i],
-                    res_scale_init_value=res_scale_init_values[i]
-                ) for j in range(depths[i])])
-            )]))
+            stage = MetaFormerStage(
+                last_dim,
+                dims[i],
+                depth=depths[i],
+                downsample_norm=downsample_norm,
+                token_mixer=token_mixers[i],
+                mlp=mlps[i],
+                mlp_fn=mlp_fn,
+                mlp_act=mlp_act,
+                mlp_bias=mlp_bias,
+                norm_layer=norm_layers[i],
+                dp_rates=dp_rates[i],
+                layer_scale_init_value=layer_scale_init_values[i],
+                res_scale_init_value=res_scale_init_values[i],
+            )
             stages.append(stage)
-            cur += depths[i]
+            last_dim = dims[i]
             self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]

         self.stages = nn.Sequential(*stages)
@@ -515,7 +593,7 @@ class MetaFormer(nn.Module):
             head = self.head_fn(dims[-1], num_classes)
         else:
             head = nn.Identity()

         self.norm_pre = output_norm(self.num_features) if head_norm_first else nn.Identity()
         self.head = nn.Sequential(OrderedDict([
             ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
@@ -534,6 +612,8 @@ class MetaFormer(nn.Module):
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable
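+        # propagate the flag so each stage can checkpoint its own block sequence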
+        for stage in self.stages:
+            stage.set_grad_checkpointing(enable=enable)

     @torch.jit.ignore
     def get_classifier(self):
@@ -552,23 +632,23 @@ class MetaFormer(nn.Module):
             head = nn.Identity()
         self.head.fc = head

-    def forward_head(self, x, pre_logits: bool = False):
+    def forward_head(self, x: Tensor, pre_logits: bool = False):
         # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
         x = self.head.global_pool(x)
         x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         x = self.head.flatten(x)
         return x if pre_logits else self.head.fc(x)

-    def forward_features(self, x):
+    def forward_features(self, x: Tensor):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
             x = checkpoint_seq(self.stages, x)
         else:
             x = self.stages(x)
-        x = self.norm_pre(x)
+        x = self.norm_pre(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
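+        # norm_pre, like the other norms, operates channels-last, hence the same round-trip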
         return x

-    def forward(self, x):
+    def forward(self, x: Tensor):
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
@@ -595,6 +675,8 @@ def checkpoint_filter_fn(state_dict, model):
         k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
         k = k.replace('stages.0.downsample', 'patch_embed')
+        k = k.replace('patch_embed', 'stem')
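+        # fold the old pre_norm/post_norm names into the single Downsampling norm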
+        k = k.replace('post_norm', 'norm')
+        k = k.replace('pre_norm', 'norm')
         k = re.sub(r'^head', 'head.fc', k)
         k = re.sub(r'^norm', 'head.norm', k)
         out_dict[k] = v
@@ -684,7 +766,7 @@ default_cfgs = generate_default_cfgs({
         classifier='head.fc.fc2'),
     'convformer_s18.sail_in1k_384': _cfg(
         url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth',
-        classifier='head.fc.fc2', input_size=(3, 384, 384)),
+        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
     'convformer_s18.sail_in22k_ft_in1k': _cfg(
         url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth',
         classifier='head.fc.fc2'),