diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 90241f5c..4aae99e3 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -20,7 +20,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible from .inplace_abn import InplaceAbn from .linear import Linear from .mixed_conv2d import MixedConv2d -from .mlp import Mlp, GluMlp +from .mlp import Mlp, GluMlp, GatedMlp from .norm import GroupNorm from .norm_act import BatchNormAct2d, GroupNormAct from .padding import get_padding, get_same_padding, pad_same diff --git a/timm/models/layers/mlp.py b/timm/models/layers/mlp.py index b65c8d07..b3f8de11 100644 --- a/timm/models/layers/mlp.py +++ b/timm/models/layers/mlp.py @@ -34,9 +34,10 @@ class GluMlp(nn.Module): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features * 2) + assert hidden_features % 2 == 0 + self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = nn.Linear(hidden_features // 2, out_features) self.drop = nn.Dropout(drop) def forward(self, x): @@ -47,3 +48,32 @@ class GluMlp(nn.Module): x = self.fc2(x) x = self.drop(x) return x + + +class GatedMlp(nn.Module): + """ MLP as used in gMLP + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + gate_layer=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + if gate_layer is not None: + assert hidden_features % 2 == 0 + self.gate = gate_layer(hidden_features) + hidden_features = hidden_features // 2 # FIXME base reduction on gate property? + else: + self.gate = nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.gate(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 248568fc..2241fe43 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -1,4 +1,6 @@ -""" MLP-Mixer in PyTorch +""" MLP-Mixer, ResMLP, and gMLP in PyTorch + +This impl originally based on MLP-Mixer paper. Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py @@ -12,6 +14,25 @@ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2 year={2021} } +Also supporting preliminary (not verified) implementations of ResMlp, gMLP, and possibly more... + +Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 +@misc{touvron2021resmlp, + title={ResMLP: Feedforward networks for image classification with data-efficient training}, + author={Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and + Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou}, + year={2021}, + eprint={2105.03404}, +} + +Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 +@misc{liu2021pay, + title={Pay Attention to MLPs}, + author={Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le}, + year={2021}, + eprint={2105.08050}, +} + A thank you to paper authors for releasing code and weights. Hacked together by / Copyright 2021 Ross Wightman @@ -25,7 +46,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .layers import PatchEmbed, Mlp, GluMlp, DropPath, lecun_normal_ +from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple from .registry import register_model @@ -43,7 +64,6 @@ def _cfg(url='', **kwargs): default_cfgs = dict( mixer_s32_224=_cfg(), mixer_s16_224=_cfg(), - mixer_s16_glu_224=_cfg(), mixer_b32_224=_cfg(), mixer_b16_224=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth', @@ -60,15 +80,29 @@ default_cfgs = dict( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth', num_classes=21843 ), + + gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + gmixer_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + resmlp_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + resmlp_36_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + + gmlp_ti16_224=_cfg(), + gmlp_s16_224=_cfg(), + gmlp_b16_224=_cfg(), ) class MixerBlock(nn.Module): - + """ Residual Block w/ token mixing and channel MLPs + Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 + """ def __init__( - self, dim, seq_len, tokens_dim, channels_dim, - mlp_layer=Mlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): + self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): super().__init__() + tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)] self.norm1 = norm_layer(dim) self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -81,6 +115,78 @@ class MixerBlock(nn.Module): return x +class Affine(nn.Module): + def __init__(self, dim): + super().__init__() + self.alpha = nn.Parameter(torch.ones((1, 1, dim))) + self.beta = nn.Parameter(torch.zeros((1, 1, dim))) + + def forward(self, x): + return torch.addcmul(self.beta, self.alpha, x) + + +class ResBlock(nn.Module): + """ Residual MLP block w/ LayerScale and Affine 'norm' + + Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + def __init__( + self, dim, seq_len, mlp_ratio=4, mlp_layer=Mlp, norm_layer=Affine, + act_layer=nn.GELU, init_values=1e-4, drop=0., drop_path=0.): + super().__init__() + channel_dim = int(dim * mlp_ratio) + self.norm1 = norm_layer(dim) + self.linear_tokens = nn.Linear(seq_len, seq_len) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop) + self.ls1 = nn.Parameter(init_values * torch.ones(dim)) + self.ls2 = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + x = x + self.drop_path(self.ls1 * self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)) + x = x + self.drop_path(self.ls2 * self.mlp_channels(self.norm2(x))) + return x + + +class SpatialGatingUnit(nn.Module): + """ Spatial Gating Unit + + Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm): + super().__init__() + gate_dim = dim // 2 + self.norm = norm_layer(gate_dim) + self.proj = nn.Linear(seq_len, seq_len) + + def forward(self, x): + u, v = x.chunk(2, dim=-1) + v = self.norm(v) + v = self.proj(v.transpose(-1, -2)) + return u * v.transpose(-1, -2) + + +class SpatialGatingBlock(nn.Module): + """ Residual Block w/ Spatial Gating + + Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + def __init__( + self, dim, seq_len, mlp_ratio=4, mlp_layer=GatedMlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): + super().__init__() + channel_dim = int(dim * mlp_ratio) + self.norm = norm_layer(dim) + sgu = partial(SpatialGatingUnit, seq_len=seq_len) + self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + x = x + self.drop_path(self.mlp_channels(self.norm(x))) + return x + + class MlpMixer(nn.Module): def __init__( @@ -91,24 +197,27 @@ class MlpMixer(nn.Module): patch_size=16, num_blocks=8, hidden_dim=512, - tokens_dim=256, - channels_dim=2048, + mlp_ratio=(0.5, 4.0), + block_layer=MixerBlock, mlp_layer=Mlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop_rate=0., drop_path_rate=0., nlhb=False, + stem_norm=False, ): super().__init__() self.num_classes = num_classes - self.stem = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim) - # FIXME drop_path (stochastic depth scaling rule?) + self.stem = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim, + norm_layer=norm_layer if stem_norm else None) + # FIXME drop_path (stochastic depth scaling rule or all the same?) self.blocks = nn.Sequential(*[ - MixerBlock( - hidden_dim, self.stem.num_patches, tokens_dim, channels_dim, - mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate) + block_layer( + hidden_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer, + act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate) for _ in range(num_blocks)]) self.norm = norm_layer(hidden_dim) self.head = nn.Linear(hidden_dim, self.num_classes) # zero init @@ -136,6 +245,9 @@ def _init_weights(m, n: str, head_bias: float = 0.): if n.startswith('head'): nn.init.zeros_(m.weight) nn.init.constant_(m.bias, head_bias) + elif n.endswith('gate.proj'): + nn.init.normal_(m.weight, std=1e-4) + nn.init.ones_(m.bias) else: nn.init.xavier_uniform_(m.weight) if m.bias is not None: @@ -177,8 +289,9 @@ def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs): @register_model def mixer_s32_224(pretrained=False, **kwargs): """ Mixer-S/32 224x224 + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs) + model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, **kwargs) model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args) return model @@ -186,28 +299,19 @@ def mixer_s32_224(pretrained=False, **kwargs): @register_model def mixer_s16_224(pretrained=False, **kwargs): """ Mixer-S/16 224x224 + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs) + model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, **kwargs) model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args) return model -@register_model -def mixer_s16_glu_224(pretrained=False, **kwargs): - """ Mixer-S/16 224x224 - """ - model_args = dict( - patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=1536, - mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) - model = _create_mixer('mixer_s16_glu_224', pretrained=pretrained, **model_args) - return model - - @register_model def mixer_b32_224(pretrained=False, **kwargs): """ Mixer-B/32 224x224 + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs) + model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, **kwargs) model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args) return model @@ -215,8 +319,9 @@ def mixer_b32_224(pretrained=False, **kwargs): @register_model def mixer_b16_224(pretrained=False, **kwargs): """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs) + model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args) return model @@ -224,8 +329,9 @@ def mixer_b16_224(pretrained=False, **kwargs): @register_model def mixer_b16_224_in21k(pretrained=False, **kwargs): """ Mixer-B/16 224x224. ImageNet-21k pretrained weights. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs) + model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args) return model @@ -233,8 +339,9 @@ def mixer_b16_224_in21k(pretrained=False, **kwargs): @register_model def mixer_l32_224(pretrained=False, **kwargs): """ Mixer-L/32 224x224. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs) + model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, **kwargs) model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args) return model @@ -242,8 +349,9 @@ def mixer_l32_224(pretrained=False, **kwargs): @register_model def mixer_l16_224(pretrained=False, **kwargs): """ Mixer-L/16 224x224. ImageNet-1k pretrained weights. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs) + model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs) model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args) return model @@ -251,7 +359,101 @@ def mixer_l16_224(pretrained=False, **kwargs): @register_model def mixer_l16_224_in21k(pretrained=False, **kwargs): """ Mixer-L/16 224x224. ImageNet-21k pretrained weights. + Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs) + model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs) model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args) return model + + +@register_model +def gmixer_12_224(pretrained=False, **kwargs): + """ Glu-Mixer-12 224x224 (short & fat) + Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer + """ + model_args = dict( + patch_size=20, num_blocks=12, hidden_dim=512, mlp_ratio=(1.0, 6.0), + mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) + model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmixer_24_224(pretrained=False, **kwargs): + """ Glu-Mixer-24 224x224 (tall & slim) + Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer + """ + model_args = dict( + patch_size=20, num_blocks=24, hidden_dim=384, mlp_ratio=(1.0, 6.0), + mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) + model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_12_224(pretrained=False, **kwargs): + """ ResMLP-12 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=12, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_24_224(pretrained=False, **kwargs): + """ ResMLP-24 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def resmlp_36_224(pretrained=False, **kwargs): + """ ResMLP-36 + Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 + """ + model_args = dict( + patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmlp_ti16_224(pretrained=False, **kwargs): + """ gMLP-Tiny + Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + model_args = dict( + patch_size=16, num_blocks=30, hidden_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock, + mlp_layer=GatedMlp, **kwargs) + model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmlp_s16_224(pretrained=False, **kwargs): + """ gMLP-Small + Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + model_args = dict( + patch_size=16, num_blocks=30, hidden_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock, + mlp_layer=GatedMlp, **kwargs) + model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args) + return model + + +@register_model +def gmlp_b16_224(pretrained=False, **kwargs): + """ gMLP-Base + Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 + """ + model_args = dict( + patch_size=16, num_blocks=30, hidden_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock, + mlp_layer=GatedMlp, **kwargs) + model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args) + return model