Add preliminary gMLP and ResMLP impl to Mlp-Mixer

pull/647/head
Ross Wightman 4 years ago
parent e7f0db8664
commit d5af752117

@ -20,7 +20,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .inplace_abn import InplaceAbn from .inplace_abn import InplaceAbn
from .linear import Linear from .linear import Linear
from .mixed_conv2d import MixedConv2d from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp from .mlp import Mlp, GluMlp, GatedMlp
from .norm import GroupNorm from .norm import GroupNorm
from .norm_act import BatchNormAct2d, GroupNormAct from .norm_act import BatchNormAct2d, GroupNormAct
from .padding import get_padding, get_same_padding, pad_same from .padding import get_padding, get_same_padding, pad_same

@ -34,9 +34,10 @@ class GluMlp(nn.Module):
super().__init__() super().__init__()
out_features = out_features or in_features out_features = out_features or in_features
hidden_features = hidden_features or in_features hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features * 2) assert hidden_features % 2 == 0
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer() self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features) self.fc2 = nn.Linear(hidden_features // 2, out_features)
self.drop = nn.Dropout(drop) self.drop = nn.Dropout(drop)
def forward(self, x): def forward(self, x):
@ -47,3 +48,32 @@ class GluMlp(nn.Module):
x = self.fc2(x) x = self.fc2(x)
x = self.drop(x) x = self.drop(x)
return x return x
class GatedMlp(nn.Module):
""" MLP as used in gMLP
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
gate_layer=None, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
if gate_layer is not None:
assert hidden_features % 2 == 0
self.gate = gate_layer(hidden_features)
hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
else:
self.gate = nn.Identity()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.gate(x)
x = self.fc2(x)
x = self.drop(x)
return x

@ -1,4 +1,6 @@
""" MLP-Mixer in PyTorch """ MLP-Mixer, ResMLP, and gMLP in PyTorch
This impl originally based on MLP-Mixer paper.
Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
@ -12,6 +14,25 @@ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2
year={2021} year={2021}
} }
Also supporting preliminary (not verified) implementations of ResMlp, gMLP, and possibly more...
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
@misc{touvron2021resmlp,
title={ResMLP: Feedforward networks for image classification with data-efficient training},
author={Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and
Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
year={2021},
eprint={2105.03404},
}
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
@misc{liu2021pay,
title={Pay Attention to MLPs},
author={Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
year={2021},
eprint={2105.08050},
}
A thank you to paper authors for releasing code and weights. A thank you to paper authors for releasing code and weights.
Hacked together by / Copyright 2021 Ross Wightman Hacked together by / Copyright 2021 Ross Wightman
@ -25,7 +46,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, GluMlp, DropPath, lecun_normal_ from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
from .registry import register_model from .registry import register_model
@ -43,7 +64,6 @@ def _cfg(url='', **kwargs):
default_cfgs = dict( default_cfgs = dict(
mixer_s32_224=_cfg(), mixer_s32_224=_cfg(),
mixer_s16_224=_cfg(), mixer_s16_224=_cfg(),
mixer_s16_glu_224=_cfg(),
mixer_b32_224=_cfg(), mixer_b32_224=_cfg(),
mixer_b16_224=_cfg( mixer_b16_224=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
@ -60,15 +80,29 @@ default_cfgs = dict(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
num_classes=21843 num_classes=21843
), ),
gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
gmixer_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_36_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
gmlp_ti16_224=_cfg(),
gmlp_s16_224=_cfg(),
gmlp_b16_224=_cfg(),
) )
class MixerBlock(nn.Module): class MixerBlock(nn.Module):
""" Residual Block w/ token mixing and channel MLPs
Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
"""
def __init__( def __init__(
self, dim, seq_len, tokens_dim, channels_dim, self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp,
mlp_layer=Mlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
super().__init__() super().__init__()
tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop) self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -81,6 +115,78 @@ class MixerBlock(nn.Module):
return x return x
class Affine(nn.Module):
def __init__(self, dim):
super().__init__()
self.alpha = nn.Parameter(torch.ones((1, 1, dim)))
self.beta = nn.Parameter(torch.zeros((1, 1, dim)))
def forward(self, x):
return torch.addcmul(self.beta, self.alpha, x)
class ResBlock(nn.Module):
""" Residual MLP block w/ LayerScale and Affine 'norm'
Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
def __init__(
self, dim, seq_len, mlp_ratio=4, mlp_layer=Mlp, norm_layer=Affine,
act_layer=nn.GELU, init_values=1e-4, drop=0., drop_path=0.):
super().__init__()
channel_dim = int(dim * mlp_ratio)
self.norm1 = norm_layer(dim)
self.linear_tokens = nn.Linear(seq_len, seq_len)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop)
self.ls1 = nn.Parameter(init_values * torch.ones(dim))
self.ls2 = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
x = x + self.drop_path(self.ls1 * self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
x = x + self.drop_path(self.ls2 * self.mlp_channels(self.norm2(x)))
return x
class SpatialGatingUnit(nn.Module):
""" Spatial Gating Unit
Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
"""
def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm):
super().__init__()
gate_dim = dim // 2
self.norm = norm_layer(gate_dim)
self.proj = nn.Linear(seq_len, seq_len)
def forward(self, x):
u, v = x.chunk(2, dim=-1)
v = self.norm(v)
v = self.proj(v.transpose(-1, -2))
return u * v.transpose(-1, -2)
class SpatialGatingBlock(nn.Module):
""" Residual Block w/ Spatial Gating
Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
"""
def __init__(
self, dim, seq_len, mlp_ratio=4, mlp_layer=GatedMlp,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
super().__init__()
channel_dim = int(dim * mlp_ratio)
self.norm = norm_layer(dim)
sgu = partial(SpatialGatingUnit, seq_len=seq_len)
self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
x = x + self.drop_path(self.mlp_channels(self.norm(x)))
return x
class MlpMixer(nn.Module): class MlpMixer(nn.Module):
def __init__( def __init__(
@ -91,24 +197,27 @@ class MlpMixer(nn.Module):
patch_size=16, patch_size=16,
num_blocks=8, num_blocks=8,
hidden_dim=512, hidden_dim=512,
tokens_dim=256, mlp_ratio=(0.5, 4.0),
channels_dim=2048, block_layer=MixerBlock,
mlp_layer=Mlp, mlp_layer=Mlp,
norm_layer=partial(nn.LayerNorm, eps=1e-6), norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU, act_layer=nn.GELU,
drop_rate=0., drop_rate=0.,
drop_path_rate=0., drop_path_rate=0.,
nlhb=False, nlhb=False,
stem_norm=False,
): ):
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.stem = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim) self.stem = PatchEmbed(
# FIXME drop_path (stochastic depth scaling rule?) img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim,
norm_layer=norm_layer if stem_norm else None)
# FIXME drop_path (stochastic depth scaling rule or all the same?)
self.blocks = nn.Sequential(*[ self.blocks = nn.Sequential(*[
MixerBlock( block_layer(
hidden_dim, self.stem.num_patches, tokens_dim, channels_dim, hidden_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate) act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
for _ in range(num_blocks)]) for _ in range(num_blocks)])
self.norm = norm_layer(hidden_dim) self.norm = norm_layer(hidden_dim)
self.head = nn.Linear(hidden_dim, self.num_classes) # zero init self.head = nn.Linear(hidden_dim, self.num_classes) # zero init
@ -136,6 +245,9 @@ def _init_weights(m, n: str, head_bias: float = 0.):
if n.startswith('head'): if n.startswith('head'):
nn.init.zeros_(m.weight) nn.init.zeros_(m.weight)
nn.init.constant_(m.bias, head_bias) nn.init.constant_(m.bias, head_bias)
elif n.endswith('gate.proj'):
nn.init.normal_(m.weight, std=1e-4)
nn.init.ones_(m.bias)
else: else:
nn.init.xavier_uniform_(m.weight) nn.init.xavier_uniform_(m.weight)
if m.bias is not None: if m.bias is not None:
@ -177,8 +289,9 @@ def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs):
@register_model @register_model
def mixer_s32_224(pretrained=False, **kwargs): def mixer_s32_224(pretrained=False, **kwargs):
""" Mixer-S/32 224x224 """ Mixer-S/32 224x224
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs) model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, **kwargs)
model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args) model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
return model return model
@ -186,28 +299,19 @@ def mixer_s32_224(pretrained=False, **kwargs):
@register_model @register_model
def mixer_s16_224(pretrained=False, **kwargs): def mixer_s16_224(pretrained=False, **kwargs):
""" Mixer-S/16 224x224 """ Mixer-S/16 224x224
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs) model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, **kwargs)
model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args) model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
return model return model
@register_model
def mixer_s16_glu_224(pretrained=False, **kwargs):
""" Mixer-S/16 224x224
"""
model_args = dict(
patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=1536,
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
model = _create_mixer('mixer_s16_glu_224', pretrained=pretrained, **model_args)
return model
@register_model @register_model
def mixer_b32_224(pretrained=False, **kwargs): def mixer_b32_224(pretrained=False, **kwargs):
""" Mixer-B/32 224x224 """ Mixer-B/32 224x224
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs) model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, **kwargs)
model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args) model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
return model return model
@ -215,8 +319,9 @@ def mixer_b32_224(pretrained=False, **kwargs):
@register_model @register_model
def mixer_b16_224(pretrained=False, **kwargs): def mixer_b16_224(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights. """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs) model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args) model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
return model return model
@ -224,8 +329,9 @@ def mixer_b16_224(pretrained=False, **kwargs):
@register_model @register_model
def mixer_b16_224_in21k(pretrained=False, **kwargs): def mixer_b16_224_in21k(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights. """ Mixer-B/16 224x224. ImageNet-21k pretrained weights.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs) model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args) model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
return model return model
@ -233,8 +339,9 @@ def mixer_b16_224_in21k(pretrained=False, **kwargs):
@register_model @register_model
def mixer_l32_224(pretrained=False, **kwargs): def mixer_l32_224(pretrained=False, **kwargs):
""" Mixer-L/32 224x224. """ Mixer-L/32 224x224.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs) model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, **kwargs)
model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args) model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
return model return model
@ -242,8 +349,9 @@ def mixer_l32_224(pretrained=False, **kwargs):
@register_model @register_model
def mixer_l16_224(pretrained=False, **kwargs): def mixer_l16_224(pretrained=False, **kwargs):
""" Mixer-L/16 224x224. ImageNet-1k pretrained weights. """ Mixer-L/16 224x224. ImageNet-1k pretrained weights.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs) model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs)
model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args) model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
return model return model
@ -251,7 +359,101 @@ def mixer_l16_224(pretrained=False, **kwargs):
@register_model @register_model
def mixer_l16_224_in21k(pretrained=False, **kwargs): def mixer_l16_224_in21k(pretrained=False, **kwargs):
""" Mixer-L/16 224x224. ImageNet-21k pretrained weights. """ Mixer-L/16 224x224. ImageNet-21k pretrained weights.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs) model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs)
model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args) model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
return model return model
@register_model
def gmixer_12_224(pretrained=False, **kwargs):
""" Glu-Mixer-12 224x224 (short & fat)
Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
"""
model_args = dict(
patch_size=20, num_blocks=12, hidden_dim=512, mlp_ratio=(1.0, 6.0),
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args)
return model
@register_model
def gmixer_24_224(pretrained=False, **kwargs):
""" Glu-Mixer-24 224x224 (tall & slim)
Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
"""
model_args = dict(
patch_size=20, num_blocks=24, hidden_dim=384, mlp_ratio=(1.0, 6.0),
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args)
return model
@register_model
def resmlp_12_224(pretrained=False, **kwargs):
""" ResMLP-12
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
model_args = dict(
patch_size=16, num_blocks=12, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args)
return model
@register_model
def resmlp_24_224(pretrained=False, **kwargs):
""" ResMLP-24
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
model_args = dict(
patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args)
return model
@register_model
def resmlp_36_224(pretrained=False, **kwargs):
""" ResMLP-36
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
"""
model_args = dict(
patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
return model
@register_model
def gmlp_ti16_224(pretrained=False, **kwargs):
""" gMLP-Tiny
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
"""
model_args = dict(
patch_size=16, num_blocks=30, hidden_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
mlp_layer=GatedMlp, **kwargs)
model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args)
return model
@register_model
def gmlp_s16_224(pretrained=False, **kwargs):
""" gMLP-Small
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
"""
model_args = dict(
patch_size=16, num_blocks=30, hidden_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
mlp_layer=GatedMlp, **kwargs)
model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args)
return model
@register_model
def gmlp_b16_224(pretrained=False, **kwargs):
""" gMLP-Base
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
"""
model_args = dict(
patch_size=16, num_blocks=30, hidden_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
mlp_layer=GatedMlp, **kwargs)
model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args)
return model

Loading…
Cancel
Save