Squashed commit of the following:

commit b7696a30a772dbbb2e00d81e7096c24dac97df73
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Fri Feb 10 01:46:44 2023 -0800

    Update metaformers.py

commit 41fe5c36263b40a6cd7caddb85b10c5d82d48023
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Fri Feb 10 01:03:47 2023 -0800

    Update metaformers.py

commit a3aee37c35985c01ca07902d860f809b648c612c
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Fri Feb 10 00:32:04 2023 -0800

    Update metaformers.py

commit f938beb81b4f46851d6d6f04ae7a9a74871ee40d
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Fri Feb 10 00:24:58 2023 -0800

    Update metaformers.py

commit 10bde717e51c95cdf20135c8bba77a7a1b00d78c
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sun Feb 5 02:11:28 2023 -0800

    Update metaformers.py

commit 39274bd45e78b8ead0509367f800121f0d7c25f4
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sun Feb 5 02:06:58 2023 -0800

    Update metaformers.py

commit a2329ab8ec00d0ebc00979690293c3887cc44a4c
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sun Feb 5 02:03:34 2023 -0800

    Update metaformers.py

commit 53b8ce5b8a6b6d828de61788bcc2e6043ebb3081
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sun Feb 5 02:02:37 2023 -0800

    Update metaformers.py

commit ab6225b9414f534815958036f6d5a392038d7ab2
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sun Feb 5 01:04:55 2023 -0800

    try NHWC

commit 02fcc30eaa67a3c92cae56f3062b2542c32c9283
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sat Feb 4 23:47:06 2023 -0800

    Update metaformers.py

commit 366aae93047934bd3d7d37a077e713d424fa429c
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Sat Feb 4 23:37:30 2023 -0800

    Stem/Downsample rework

commit 26a8e481a5cb2a32004a796bc25c6800cd2fb7b7
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Wed Feb 1 07:42:07 2023 -0800

    Update metaformers.py

commit a913f5d4384aa4b2f62fdab46254ae3772df00ee
Author: Fredo Guan <fredo.guan@hotmail.com>
Date:   Wed Feb 1 07:41:24 2023 -0800

    Update metaformers.py
commit e2a9408dd0 (pull/1647/head)
parent 0b1f84142f
Author: Fredo Guan <fredo.guan@hotmail.com>

@@ -24,8 +24,11 @@ Adapted from https://github.com/sail-sg/metaformer, original copyright below
 from collections import OrderedDict
 from functools import partial

 import torch
 import torch.nn as nn
+from torch import Tensor

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1
 from timm.layers.helpers import to_2tuple
@@ -40,28 +43,58 @@ from ._registry import register_model

 __all__ = ['MetaFormer']


+class Stem(nn.Module):
+    """
+    Stem implemented by a layer of convolution.
+    Conv2d params constant across all models.
+    """
+    def __init__(self,
+            in_channels,
+            out_channels,
+            norm_layer=None,
+    ):
+        super().__init__()
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=7,
+            stride=4,
+            padding=2
+        )
+        self.norm = norm_layer(out_channels) if norm_layer else nn.Identity()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        # [B, C, H, W]
+        return x
+
+
 class Downsampling(nn.Module):
     """
     Downsampling implemented by a layer of convolution.
     """
-    def __init__(self, in_channels, out_channels,
-            kernel_size, stride=1, padding=0,
-            pre_norm=None, post_norm=None, pre_permute=False):
+    def __init__(self,
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride=1,
+            padding=0,
+            norm_layer=None,
+    ):
         super().__init__()
-        self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
-        self.pre_permute = pre_permute
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
-            stride=stride, padding=padding)
-        self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()
+        self.norm = norm_layer(in_channels) if norm_layer else nn.Identity()
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding
+        )

     def forward(self, x):
-        if self.pre_permute:
-            # if take [B, H, W, C] as input, permute it to [B, C, H, W]
-            x = x.permute(0, 3, 1, 2)
-        x = self.pre_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         x = self.conv(x)
-        x = self.post_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         return x
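A quick shape check of the reworked stem and per-stage downsample geometry (a minimal sketch assuming the Stem/Downsampling definitions above; plain nn.LayerNorm stands in for the LayerNormGeneral layers used in the file, and the 224x224 input is only an example):

import torch
import torch.nn as nn

# Stride-4 stem conv and stride-2 stage downsample, as defined above.
stem_conv = nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=2)
stem_norm = nn.LayerNorm(64)
down_conv = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)

x = torch.randn(1, 3, 224, 224)
x = stem_conv(x)                                           # [1, 64, 56, 56]
x = stem_norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)   # norm applied channels-last
x = down_conv(x)                                           # [1, 128, 28, 28]
print(x.shape)

With kernel_size=7, stride=4, padding=2 the stem maps 224 to (224 + 4 - 7) // 4 + 1 = 56, and each stage downsample halves that again; the norm layers always run on a channels-last view via the permute pairs.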
@@ -299,7 +332,6 @@ class Mlp(nn.Module):
         return x


 class MlpHead(nn.Module):
     """ MLP classification head
     """
@@ -323,7 +355,6 @@ class MlpHead(nn.Module):
         return x


 class MetaFormerBlock(nn.Module):
     """
     Implementation of one MetaFormer block.
@@ -367,7 +398,6 @@ class MetaFormerBlock(nn.Module):
             if res_scale_init_value else nn.Identity()

     def forward(self, x):
-        x = x.permute(0, 2, 3, 1)
         x = self.res_scale1(x) + \
             self.layer_scale1(
                 self.drop_path1(
@@ -380,6 +410,69 @@ class MetaFormerBlock(nn.Module):
                     self.mlp(self.norm2(x))
                 )
             )
+        return x
+
+
+class MetaFormerStage(nn.Module):
+    # implementation of a single metaformer stage
+    def __init__(self,
+            in_chs,
+            out_chs,
+            depth=2,
+            downsample_norm=partial(LayerNormGeneral, bias=False, eps=1e-6),
+            token_mixer=nn.Identity,
+            mlp=Mlp,
+            mlp_fn=nn.Linear,
+            mlp_act=StarReLU,
+            mlp_bias=False,
+            norm_layer=partial(LayerNormGeneral, eps=1e-6, bias=False),
+            dp_rates=[0.]*2,
+            layer_scale_init_value=None,
+            res_scale_init_value=None,
+    ):
+        super().__init__()
+
+        self.grad_checkpointing = False
+
+        # don't downsample if in_chs and out_chs are the same
+        self.downsample = nn.Identity() if in_chs == out_chs else Downsampling(
+            in_chs,
+            out_chs,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            norm_layer=downsample_norm
+        )
+
+        self.blocks = nn.Sequential(*[MetaFormerBlock(
+            dim=out_chs,
+            token_mixer=token_mixer,
+            mlp=mlp,
+            mlp_fn=mlp_fn,
+            mlp_act=mlp_act,
+            mlp_bias=mlp_bias,
+            norm_layer=norm_layer,
+            drop_path=dp_rates[i],
+            layer_scale_init_value=layer_scale_init_value,
+            res_scale_init_value=res_scale_init_value
+        ) for i in range(depth)])
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
+    # Permute to channels-first for feature extraction
+    def forward(self, x: Tensor):
+        # [B, C, H, W] -> [B, H, W, C]
+        x = self.downsample(x).permute(0, 2, 3, 1)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
+        # [B, H, W, C] -> [B, C, H, W]
         x = x.permute(0, 3, 1, 2)
         return x
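The layout contract introduced here: MetaFormerBlock now assumes channels-last input, and MetaFormerStage owns the NCHW <-> NHWC conversion around its block sequence. A minimal sketch of that flow, with a hypothetical DummyBlock standing in for the real MetaFormerBlock:

import torch
import torch.nn as nn

# DummyBlock is a stand-in, not the real block: it just shows that block-level
# ops (LayerNorm/Linear-style) run natively on the trailing channel dim.
class DummyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)   # normalizes the last (channel) dim

    def forward(self, x):               # x: [B, H, W, C]
        return x + self.norm(x)

downsample = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
blocks = nn.Sequential(DummyBlock(128), DummyBlock(128))

x = torch.randn(2, 64, 56, 56)           # [B, C, H, W] in
x = downsample(x).permute(0, 2, 3, 1)    # [B, 28, 28, 128], channels-last for blocks
x = blocks(x)
x = x.permute(0, 3, 1, 2)                # [B, 128, 28, 28], channels-first out
print(x.shape)                           # torch.Size([2, 128, 28, 28])

Keeping the permutes at the stage boundary removes the per-block permute pair while still exposing channels-first features to the rest of timm (feature_info, pooling, etc.).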
@@ -415,7 +508,7 @@ class MetaFormer(nn.Module):
             token_mixers=nn.Identity,
             mlps=Mlp,
             mlp_fn=nn.Linear,
-            mlp_act = StarReLU,
+            mlp_act=StarReLU,
             mlp_bias=False,
             norm_layers=partial(LayerNormGeneral, eps=1e-6, bias=False),
             drop_path_rate=0.,
@@ -433,24 +526,19 @@ class MetaFormer(nn.Module):
         self.head_fn = head_fn
         self.num_features = dims[-1]
         self.drop_rate = drop_rate
+        self.num_stages = len(depths)
+        # convert everything to lists if they aren't indexable
         if not isinstance(depths, (list, tuple)):
             depths = [depths]  # it means the model has only one stage
         if not isinstance(dims, (list, tuple)):
             dims = [dims]
-        self.num_stages = len(depths)
         if not isinstance(token_mixers, (list, tuple)):
             token_mixers = [token_mixers] * self.num_stages
         if not isinstance(mlps, (list, tuple)):
             mlps = [mlps] * self.num_stages
         if not isinstance(norm_layers, (list, tuple)):
             norm_layers = [norm_layers] * self.num_stages
         if not isinstance(layer_scale_init_values, (list, tuple)):
             layer_scale_init_values = [layer_scale_init_values] * self.num_stages
         if not isinstance(res_scale_init_values, (list, tuple)):
@@ -459,47 +547,37 @@ class MetaFormer(nn.Module):
         self.grad_checkpointing = False
         self.feature_info = []

-        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]

-        self.stem = Downsampling(
+        self.stem = Stem(
             in_chans,
             dims[0],
-            kernel_size=7,
-            stride=4,
-            padding=2,
-            post_norm=downsample_norm
+            norm_layer=downsample_norm
         )

         stages = nn.ModuleList()  # each stage consists of multiple metaformer blocks
         cur = 0
+        last_dim = dims[0]
         for i in range(self.num_stages):
-            stage = nn.Sequential(OrderedDict([
-                ('downsample', nn.Identity() if i == 0 else Downsampling(
-                    dims[i-1],
-                    dims[i],
-                    kernel_size=3,
-                    stride=2,
-                    padding=1,
-                    pre_norm=downsample_norm,
-                    pre_permute=False
-                )),
-                ('blocks', nn.Sequential(*[MetaFormerBlock(
-                    dim=dims[i],
-                    token_mixer=token_mixers[i],
-                    mlp=mlps[i],
-                    mlp_fn=mlp_fn,
-                    mlp_act=mlp_act,
-                    mlp_bias=mlp_bias,
-                    norm_layer=norm_layers[i],
-                    drop_path=dp_rates[cur + j],
-                    layer_scale_init_value=layer_scale_init_values[i],
-                    res_scale_init_value=res_scale_init_values[i]
-                ) for j in range(depths[i])])
-                )])
+            stage = MetaFormerStage(
+                last_dim,
+                dims[i],
+                depth=depths[i],
+                downsample_norm=downsample_norm,
+                token_mixer=token_mixers[i],
+                mlp=mlps[i],
+                mlp_fn=mlp_fn,
+                mlp_act=mlp_act,
+                mlp_bias=mlp_bias,
+                norm_layer=norm_layers[i],
+                dp_rates=dp_rates[i],
+                layer_scale_init_value=layer_scale_init_values[i],
+                res_scale_init_value=res_scale_init_values[i],
             )
             stages.append(stage)
             cur += depths[i]
+            last_dim = dims[i]
             self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]

         self.stages = nn.Sequential(*stages)
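The drop-path schedule also changes from one flat list indexed with cur + j to per-stage lists handed to each MetaFormerStage. A worked example of what the new expression produces (depths and drop_path_rate values chosen only for illustration, not taken from the diff):

import torch

depths = [2, 2, 6, 2]
drop_path_rate = 0.1
# Same expression as in the diff: one linearly increasing rate per block,
# then split into one list per stage.
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
for i, rates in enumerate(dp_rates):
    print(f'stage {i}:', [round(r, 4) for r in rates])
# approx. output:
# stage 0: [0.0, 0.0091]
# stage 1: [0.0182, 0.0273]
# stage 2: [0.0364, 0.0455, 0.0545, 0.0636, 0.0727, 0.0818]
# stage 3: [0.0909, 0.1]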
@@ -515,7 +593,7 @@ class MetaFormer(nn.Module):
             head = self.head_fn(dims[-1], num_classes)
         else:
             head = nn.Identity()

         self.norm_pre = output_norm(self.num_features) if head_norm_first else nn.Identity()
         self.head = nn.Sequential(OrderedDict([
             ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
@@ -534,6 +612,8 @@ class MetaFormer(nn.Module):
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable
+        for stage in self.stages:
+            stage.set_grad_checkpointing(enable=enable)

     @torch.jit.ignore
     def get_classifier(self):
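Usage stays the standard timm API; a brief sketch, assuming the convformer_s18 variant registered later in this file builds at this point of the PR:

import timm
import torch

# set_grad_checkpointing() now also flips the per-stage flag, so each stage's
# block sequence can be checkpointed during training.
model = timm.create_model('convformer_s18', pretrained=False)
model.set_grad_checkpointing(True)

x = torch.randn(2, 3, 224, 224, requires_grad=True)
model(x).sum().backward()   # checkpointed activations are recomputed here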
@@ -552,23 +632,23 @@ class MetaFormer(nn.Module):
             head = nn.Identity()
         self.head.fc = head

-    def forward_head(self, x, pre_logits: bool = False):
+    def forward_head(self, x: Tensor, pre_logits: bool = False):
         # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
         x = self.head.global_pool(x)
         x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         x = self.head.flatten(x)
         return x if pre_logits else self.head.fc(x)

-    def forward_features(self, x):
+    def forward_features(self, x: Tensor):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
             x = checkpoint_seq(self.stages, x)
         else:
             x = self.stages(x)
-        x = self.norm_pre(x)
+        x = self.norm_pre(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         return x

-    def forward(self, x):
+    def forward(self, x: Tensor):
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
@@ -595,6 +675,8 @@ def checkpoint_filter_fn(state_dict, model):
         k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
         k = k.replace('stages.0.downsample', 'patch_embed')
         k = k.replace('patch_embed', 'stem')
+        k = k.replace('post_norm', 'norm')
+        k = k.replace('pre_norm', 'norm')
         k = re.sub(r'^head', 'head.fc', k)
         k = re.sub(r'^norm', 'head.norm', k)
         out_dict[k] = v
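To illustrate the remapping, a standalone sketch reproducing only the rewrite lines visible in this hunk; the input keys are hypothetical examples, not keys from a real upstream checkpoint:

import re

def remap(k):
    # Same sequence of rewrites as shown above, including the new
    # post_norm/pre_norm -> norm replacements for the reworked stem/downsample.
    k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
    k = k.replace('stages.0.downsample', 'patch_embed')
    k = k.replace('patch_embed', 'stem')
    k = k.replace('post_norm', 'norm')
    k = k.replace('pre_norm', 'norm')
    k = re.sub(r'^head', 'head.fc', k)
    k = re.sub(r'^norm', 'head.norm', k)
    return k

print(remap('stages.1.2.norm1.weight'))               # stages.1.blocks.2.norm1.weight
print(remap('stages.0.downsample.post_norm.weight'))  # stem.norm.weight
print(remap('norm.weight'))                           # head.norm.weight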
@@ -684,7 +766,7 @@ default_cfgs = generate_default_cfgs({
         classifier='head.fc.fc2'),
     'convformer_s18.sail_in1k_384': _cfg(
         url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth',
-        classifier='head.fc.fc2', input_size=(3, 384, 384)),
+        classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12,12)),
     'convformer_s18.sail_in22k_ft_in1k': _cfg(
         url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth',
         classifier='head.fc.fc2'),
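The added pool_size=(12,12) matches the feature-map size implied by the architecture: assuming the usual four-stage layout, the stride-4 stem plus three stride-2 stage downsamples give a total reduction of 4 * 2 * 2 * 2 = 32, so a 384x384 input ends in a 12x12 final map (384 / 32 = 12), versus 7x7 for the default 224 input.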
