Squashed commit of the following:

commit b7696a30a772dbbb2e00d81e7096c24dac97df73
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Fri Feb 10 01:46:44 2023 -0800

    Update metaformers.py

commit 41fe5c36263b40a6cd7caddb85b10c5d82d48023
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Fri Feb 10 01:03:47 2023 -0800

    Update metaformers.py

commit a3aee37c35985c01ca07902d860f809b648c612c
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Fri Feb 10 00:32:04 2023 -0800

    Update metaformers.py

commit f938beb81b4f46851d6d6f04ae7a9a74871ee40d
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Fri Feb 10 00:24:58 2023 -0800

    Update metaformers.py

commit 10bde717e51c95cdf20135c8bba77a7a1b00d78c
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sun Feb 5 02:11:28 2023 -0800

    Update metaformers.py

commit 39274bd45e78b8ead0509367f800121f0d7c25f4
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sun Feb 5 02:06:58 2023 -0800

    Update metaformers.py

commit a2329ab8ec00d0ebc00979690293c3887cc44a4c
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sun Feb 5 02:03:34 2023 -0800

    Update metaformers.py

commit 53b8ce5b8a6b6d828de61788bcc2e6043ebb3081
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sun Feb 5 02:02:37 2023 -0800

    Update metaformers.py

commit ab6225b9414f534815958036f6d5a392038d7ab2
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sun Feb 5 01:04:55 2023 -0800

    try NHWC

commit 02fcc30eaa67a3c92cae56f3062b2542c32c9283
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sat Feb 4 23:47:06 2023 -0800

    Update metaformers.py

commit 366aae93047934bd3d7d37a077e713d424fa429c
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sat Feb 4 23:37:30 2023 -0800

    Stem/Downsample rework

commit 26a8e481a5cb2a32004a796bc25c6800cd2fb7b7
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Wed Feb 1 07:42:07 2023 -0800

    Update metaformers.py

commit a913f5d4384aa4b2f62fdab46254ae3772df00ee
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Wed Feb 1 07:41:24 2023 -0800

    Update metaformers.py
pull/1647/head
Fredo Guan 1 year ago
parent 0b1f84142f
commit e2a9408dd0

@ -24,8 +24,11 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
from torch import Tensor
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1
from timm.layers.helpers import to_2tuple
@ -40,28 +43,58 @@ from ._registry import register_model
__all__ = ['MetaFormer']
class Stem(nn.Module):
"""
Stem implemented by a layer of convolution.
Conv2d params constant across all models.
"""
def __init__(self,
in_channels,
out_channels,
norm_layer=None,
):
super().__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=7,
stride=4,
padding=2
)
self.norm = norm_layer(out_channels) if norm_layer else nn.Identity()
def forward(self, x):
x = self.conv(x)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
# [B, C, H, W]
return x
class Downsampling(nn.Module):
"""
Downsampling implemented by a layer of convolution.
"""
def __init__(self, in_channels, out_channels,
kernel_size, stride=1, padding=0,
pre_norm=None, post_norm=None, pre_permute=False):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
norm_layer=None,
):
super().__init__()
self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
self.pre_permute = pre_permute
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
stride=stride, padding=padding)
self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()
self.norm = norm_layer(in_channels) if norm_layer else nn.Identity()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
def forward(self, x):
if self.pre_permute:
# if take [B, H, W, C] as input, permute it to [B, C, H, W]
x = x.permute(0, 3, 1, 2)
x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.conv(x)
x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
@ -299,7 +332,6 @@ class Mlp(nn.Module):
return x
class MlpHead(nn.Module):
""" MLP classification head
"""
@ -323,7 +355,6 @@ class MlpHead(nn.Module):
return x
class MetaFormerBlock(nn.Module):
"""
Implementation of one MetaFormer block.
@ -367,7 +398,6 @@ class MetaFormerBlock(nn.Module):
if res_scale_init_value else nn.Identity()
def forward(self, x):
x = x.permute(0, 2, 3, 1)
x = self.res_scale1(x) + \
self.layer_scale1(
self.drop_path1(
@ -380,6 +410,69 @@ class MetaFormerBlock(nn.Module):
self.mlp(self.norm2(x))
)
)
return x
class MetaFormerStage(nn.Module):
# implementation of a single metaformer stage
def __init__(
self,
in_chs,
out_chs,
depth=2,
downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6),
token_mixer=nn.Identity,
mlp=Mlp,
mlp_fn=nn.Linear,
mlp_act=StarReLU,
mlp_bias=False,
norm_layer=partial(LayerNormGeneral, eps=1e-6, bias=False),
dp_rates=[0.]*2,
layer_scale_init_value=None,
res_scale_init_value=None,
):
super().__init__()
self.grad_checkpointing = False
# don't downsample if in_chs and out_chs are the same
self.downsample = nn.Identity() if in_chs == out_chs else Downsampling(
in_chs,
out_chs,
kernel_size=3,
stride=2,
padding=1,
norm_layer=downsample_norm
)
self.blocks = nn.Sequential(*[MetaFormerBlock(
dim=out_chs,
token_mixer=token_mixer,
mlp=mlp,
mlp_fn=mlp_fn,
mlp_act=mlp_act,
mlp_bias=mlp_bias,
norm_layer=norm_layer,
drop_path=dp_rates[i],
layer_scale_init_value=layer_scale_init_value,
res_scale_init_value=res_scale_init_value
) for i in range(depth)])
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
# Permute to channels-first for feature extraction
def forward(self, x: Tensor):
# [B, C, H, W] -> [B, H, W, C]
x = self.downsample(x).permute(0, 2, 3, 1)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
# [B, H, W, C] -> [B, C, H, W]
x = x.permute(0, 3, 1, 2)
return x
@ -415,7 +508,7 @@ class MetaFormer(nn.Module):
token_mixers=nn.Identity,
mlps=Mlp,
mlp_fn=nn.Linear,
mlp_act = StarReLU,
mlp_act=StarReLU,
mlp_bias=False,
norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False),
drop_path_rate=0.,
@ -433,24 +526,19 @@ class MetaFormer(nn.Module):
self.head_fn = head_fn
self.num_features = dims[-1]
self.drop_rate = drop_rate
self.num_stages = len(depths)
# convert everything to lists if they aren't indexable
if not isinstance(depths, (list, tuple)):
depths = [depths] # it means the model has only one stage
if not isinstance(dims, (list, tuple)):
dims = [dims]
self.num_stages = len(depths)
if not isinstance(token_mixers, (list, tuple)):
token_mixers = [token_mixers] * self.num_stages
if not isinstance(mlps, (list, tuple)):
mlps = [mlps] * self.num_stages
if not isinstance(norm_layers, (list, tuple)):
norm_layers = [norm_layers] * self.num_stages
if not isinstance(layer_scale_init_values, (list, tuple)):
layer_scale_init_values = [layer_scale_init_values] * self.num_stages
if not isinstance(res_scale_init_values, (list, tuple)):
@ -459,47 +547,37 @@ class MetaFormer(nn.Module):
self.grad_checkpointing = False
self.feature_info = []
dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
self.stem = Downsampling(
self.stem = Stem(
in_chans,
dims[0],
kernel_size=7,
stride=4,
padding=2,
post_norm=downsample_norm
norm_layer=downsample_norm
)
stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
cur = 0
last_dim = dims[0]
for i in range(self.num_stages):
stage = nn.Sequential(OrderedDict([
('downsample', nn.Identity() if i == 0 else Downsampling(
dims[i-1],
dims[i],
kernel_size=3,
stride=2,
padding=1,
pre_norm=downsample_norm,
pre_permute=False
)),
('blocks', nn.Sequential(*[MetaFormerBlock(
dim=dims[i],
token_mixer=token_mixers[i],
mlp=mlps[i],
mlp_fn=mlp_fn,
mlp_act=mlp_act,
mlp_bias=mlp_bias,
norm_layer=norm_layers[i],
drop_path=dp_rates[cur + j],
layer_scale_init_value=layer_scale_init_values[i],
res_scale_init_value=res_scale_init_values[i]
) for j in range(depths[i])])
)])
stage = MetaFormerStage(
last_dim,
dims[i],
depth=depths[i],
downsample_norm=downsample_norm,
token_mixer=token_mixers[i],
mlp=mlps[i],
mlp_fn=mlp_fn,
mlp_act=mlp_act,
mlp_bias=mlp_bias,
norm_layer=norm_layers[i],
dp_rates=dp_rates[i],
layer_scale_init_value=layer_scale_init_values[i],
res_scale_init_value=res_scale_init_values[i],
)
stages.append(stage)
cur += depths[i]
last_dim = dims[i]
self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
@ -515,7 +593,7 @@ class MetaFormer(nn.Module):
head = self.head_fn(dims[-1], num_classes)
else:
head = nn.Identity()
self.norm_pre = output_norm(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
@ -534,6 +612,8 @@ class MetaFormer(nn.Module):
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
for stage in self.stages:
stage.set_grad_checkpointing(enable=enable)
@torch.jit.ignore
def get_classifier(self):
@ -552,23 +632,23 @@ class MetaFormer(nn.Module):
head = nn.Identity()
self.head.fc = head
def forward_head(self, x, pre_logits: bool = False):
def forward_head(self, x: Tensor, pre_logits: bool = False):
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
x = self.head.global_pool(x)
x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.head.flatten(x)
return x if pre_logits else self.head.fc(x)
def forward_features(self, x):
def forward_features(self, x: Tensor):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.norm_pre(x)
x = self.norm_pre(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
def forward(self, x):
def forward(self, x: Tensor):
x = self.forward_features(x)
x = self.forward_head(x)
return x
@ -595,6 +675,8 @@ def checkpoint_filter_fn(state_dict, model):
k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
k = k.replace('stages.0.downsample', 'patch_embed')
k = k.replace('patch_embed', 'stem')
k = k.replace('post_norm', 'norm')
k = k.replace('pre_norm', 'norm')
k = re.sub(r'^head', 'head.fc', k)
k = re.sub(r'^norm', 'head.norm', k)
out_dict[k] = v
@ -684,7 +766,7 @@ default_cfgs = generate_default_cfgs({
classifier='head.fc.fc2'),
'convformer_s18.sail_in1k_384': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth',
classifier='head.fc.fc2', input_size=(3, 384, 384)),
classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)),
'convformer_s18.sail_in22k_ft_in1k': _cfg(
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth',
classifier='head.fc.fc2'),

Loading…
Cancel
Save