|
|
@ -1,4 +1,6 @@
|
|
|
|
""" MLP-Mixer in PyTorch
|
|
|
|
""" MLP-Mixer, ResMLP, and gMLP in PyTorch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
This impl originally based on MLP-Mixer paper.
|
|
|
|
|
|
|
|
|
|
|
|
Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
|
|
|
|
Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
|
|
|
|
|
|
|
|
|
|
|
@ -12,6 +14,25 @@ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2
|
|
|
|
year={2021}
|
|
|
|
year={2021}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Also supporting preliminary (not verified) implementations of ResMlp, gMLP, and possibly more...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
|
|
|
|
|
|
|
@misc{touvron2021resmlp,
|
|
|
|
|
|
|
|
title={ResMLP: Feedforward networks for image classification with data-efficient training},
|
|
|
|
|
|
|
|
author={Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and
|
|
|
|
|
|
|
|
Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
|
|
|
|
|
|
|
|
year={2021},
|
|
|
|
|
|
|
|
eprint={2105.03404},
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
|
|
|
|
|
|
|
|
@misc{liu2021pay,
|
|
|
|
|
|
|
|
title={Pay Attention to MLPs},
|
|
|
|
|
|
|
|
author={Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
|
|
|
|
|
|
|
|
year={2021},
|
|
|
|
|
|
|
|
eprint={2105.08050},
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
A thank you to paper authors for releasing code and weights.
|
|
|
|
A thank you to paper authors for releasing code and weights.
|
|
|
|
|
|
|
|
|
|
|
|
Hacked together by / Copyright 2021 Ross Wightman
|
|
|
|
Hacked together by / Copyright 2021 Ross Wightman
|
|
|
@ -25,7 +46,7 @@ import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
|
|
|
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
|
|
|
from .layers import PatchEmbed, Mlp, GluMlp, DropPath, lecun_normal_
|
|
|
|
from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
|
|
|
|
from .registry import register_model
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -43,7 +64,6 @@ def _cfg(url='', **kwargs):
|
|
|
|
default_cfgs = dict(
|
|
|
|
default_cfgs = dict(
|
|
|
|
mixer_s32_224=_cfg(),
|
|
|
|
mixer_s32_224=_cfg(),
|
|
|
|
mixer_s16_224=_cfg(),
|
|
|
|
mixer_s16_224=_cfg(),
|
|
|
|
mixer_s16_glu_224=_cfg(),
|
|
|
|
|
|
|
|
mixer_b32_224=_cfg(),
|
|
|
|
mixer_b32_224=_cfg(),
|
|
|
|
mixer_b16_224=_cfg(
|
|
|
|
mixer_b16_224=_cfg(
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
|
|
|
@ -60,15 +80,29 @@ default_cfgs = dict(
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
|
|
|
|
num_classes=21843
|
|
|
|
num_classes=21843
|
|
|
|
),
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
|
|
|
|
|
|
|
gmixer_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
resmlp_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
|
|
|
|
|
|
|
resmlp_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
|
|
|
|
|
|
|
resmlp_36_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gmlp_ti16_224=_cfg(),
|
|
|
|
|
|
|
|
gmlp_s16_224=_cfg(),
|
|
|
|
|
|
|
|
gmlp_b16_224=_cfg(),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MixerBlock(nn.Module):
|
|
|
|
class MixerBlock(nn.Module):
|
|
|
|
|
|
|
|
""" Residual Block w/ token mixing and channel MLPs
|
|
|
|
|
|
|
|
Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
|
|
|
|
|
|
|
"""
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self, dim, seq_len, tokens_dim, channels_dim,
|
|
|
|
self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp,
|
|
|
|
mlp_layer=Mlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
|
|
|
|
self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
@ -81,6 +115,78 @@ class MixerBlock(nn.Module):
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Affine(nn.Module):
|
|
|
|
|
|
|
|
def __init__(self, dim):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
self.alpha = nn.Parameter(torch.ones((1, 1, dim)))
|
|
|
|
|
|
|
|
self.beta = nn.Parameter(torch.zeros((1, 1, dim)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
return torch.addcmul(self.beta, self.alpha, x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResBlock(nn.Module):
|
|
|
|
|
|
|
|
""" Residual MLP block w/ LayerScale and Affine 'norm'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self, dim, seq_len, mlp_ratio=4, mlp_layer=Mlp, norm_layer=Affine,
|
|
|
|
|
|
|
|
act_layer=nn.GELU, init_values=1e-4, drop=0., drop_path=0.):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
channel_dim = int(dim * mlp_ratio)
|
|
|
|
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
|
|
|
|
self.linear_tokens = nn.Linear(seq_len, seq_len)
|
|
|
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
|
|
|
|
|
|
self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop)
|
|
|
|
|
|
|
|
self.ls1 = nn.Parameter(init_values * torch.ones(dim))
|
|
|
|
|
|
|
|
self.ls2 = nn.Parameter(init_values * torch.ones(dim))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
x = x + self.drop_path(self.ls1 * self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
|
|
|
|
|
|
|
|
x = x + self.drop_path(self.ls2 * self.mlp_channels(self.norm2(x)))
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpatialGatingUnit(nn.Module):
|
|
|
|
|
|
|
|
""" Spatial Gating Unit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
gate_dim = dim // 2
|
|
|
|
|
|
|
|
self.norm = norm_layer(gate_dim)
|
|
|
|
|
|
|
|
self.proj = nn.Linear(seq_len, seq_len)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
u, v = x.chunk(2, dim=-1)
|
|
|
|
|
|
|
|
v = self.norm(v)
|
|
|
|
|
|
|
|
v = self.proj(v.transpose(-1, -2))
|
|
|
|
|
|
|
|
return u * v.transpose(-1, -2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpatialGatingBlock(nn.Module):
|
|
|
|
|
|
|
|
""" Residual Block w/ Spatial Gating
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self, dim, seq_len, mlp_ratio=4, mlp_layer=GatedMlp,
|
|
|
|
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
channel_dim = int(dim * mlp_ratio)
|
|
|
|
|
|
|
|
self.norm = norm_layer(dim)
|
|
|
|
|
|
|
|
sgu = partial(SpatialGatingUnit, seq_len=seq_len)
|
|
|
|
|
|
|
|
self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop)
|
|
|
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
x = x + self.drop_path(self.mlp_channels(self.norm(x)))
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MlpMixer(nn.Module):
|
|
|
|
class MlpMixer(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
@ -91,24 +197,27 @@ class MlpMixer(nn.Module):
|
|
|
|
patch_size=16,
|
|
|
|
patch_size=16,
|
|
|
|
num_blocks=8,
|
|
|
|
num_blocks=8,
|
|
|
|
hidden_dim=512,
|
|
|
|
hidden_dim=512,
|
|
|
|
tokens_dim=256,
|
|
|
|
mlp_ratio=(0.5, 4.0),
|
|
|
|
channels_dim=2048,
|
|
|
|
block_layer=MixerBlock,
|
|
|
|
mlp_layer=Mlp,
|
|
|
|
mlp_layer=Mlp,
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
|
|
|
act_layer=nn.GELU,
|
|
|
|
act_layer=nn.GELU,
|
|
|
|
drop_rate=0.,
|
|
|
|
drop_rate=0.,
|
|
|
|
drop_path_rate=0.,
|
|
|
|
drop_path_rate=0.,
|
|
|
|
nlhb=False,
|
|
|
|
nlhb=False,
|
|
|
|
|
|
|
|
stem_norm=False,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
|
|
|
|
|
|
|
self.stem = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim)
|
|
|
|
self.stem = PatchEmbed(
|
|
|
|
# FIXME drop_path (stochastic depth scaling rule?)
|
|
|
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim,
|
|
|
|
|
|
|
|
norm_layer=norm_layer if stem_norm else None)
|
|
|
|
|
|
|
|
# FIXME drop_path (stochastic depth scaling rule or all the same?)
|
|
|
|
self.blocks = nn.Sequential(*[
|
|
|
|
self.blocks = nn.Sequential(*[
|
|
|
|
MixerBlock(
|
|
|
|
block_layer(
|
|
|
|
hidden_dim, self.stem.num_patches, tokens_dim, channels_dim,
|
|
|
|
hidden_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
|
|
|
|
mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
|
|
|
|
act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
|
|
|
|
for _ in range(num_blocks)])
|
|
|
|
for _ in range(num_blocks)])
|
|
|
|
self.norm = norm_layer(hidden_dim)
|
|
|
|
self.norm = norm_layer(hidden_dim)
|
|
|
|
self.head = nn.Linear(hidden_dim, self.num_classes) # zero init
|
|
|
|
self.head = nn.Linear(hidden_dim, self.num_classes) # zero init
|
|
|
@ -136,6 +245,9 @@ def _init_weights(m, n: str, head_bias: float = 0.):
|
|
|
|
if n.startswith('head'):
|
|
|
|
if n.startswith('head'):
|
|
|
|
nn.init.zeros_(m.weight)
|
|
|
|
nn.init.zeros_(m.weight)
|
|
|
|
nn.init.constant_(m.bias, head_bias)
|
|
|
|
nn.init.constant_(m.bias, head_bias)
|
|
|
|
|
|
|
|
elif n.endswith('gate.proj'):
|
|
|
|
|
|
|
|
nn.init.normal_(m.weight, std=1e-4)
|
|
|
|
|
|
|
|
nn.init.ones_(m.bias)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
nn.init.xavier_uniform_(m.weight)
|
|
|
|
nn.init.xavier_uniform_(m.weight)
|
|
|
|
if m.bias is not None:
|
|
|
|
if m.bias is not None:
|
|
|
@ -177,8 +289,9 @@ def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs):
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def mixer_s32_224(pretrained=False, **kwargs):
|
|
|
|
def mixer_s32_224(pretrained=False, **kwargs):
|
|
|
|
""" Mixer-S/32 224x224
|
|
|
|
""" Mixer-S/32 224x224
|
|
|
|
|
|
|
|
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
|
|
|
|
model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, **kwargs)
|
|
|
|
model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
@ -186,28 +299,19 @@ def mixer_s32_224(pretrained=False, **kwargs):
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def mixer_s16_224(pretrained=False, **kwargs):
|
|
|
|
def mixer_s16_224(pretrained=False, **kwargs):
|
|
|
|
""" Mixer-S/16 224x224
|
|
|
|
""" Mixer-S/16 224x224
|
|
|
|
|
|
|
|
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
|
|
|
|
model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, **kwargs)
|
|
|
|
model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def mixer_s16_glu_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" Mixer-S/16 224x224
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_args = dict(
|
|
|
|
|
|
|
|
patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=1536,
|
|
|
|
|
|
|
|
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
|
|
|
|
|
|
|
|
model = _create_mixer('mixer_s16_glu_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def mixer_b32_224(pretrained=False, **kwargs):
|
|
|
|
def mixer_b32_224(pretrained=False, **kwargs):
|
|
|
|
""" Mixer-B/32 224x224
|
|
|
|
""" Mixer-B/32 224x224
|
|
|
|
|
|
|
|
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
|
|
|
|
model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, **kwargs)
|
|
|
|
model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
@ -215,8 +319,9 @@ def mixer_b32_224(pretrained=False, **kwargs):
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def mixer_b16_224(pretrained=False, **kwargs):
|
|
|
|
def mixer_b16_224(pretrained=False, **kwargs):
|
|
|
|
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights.
|
|
|
|
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights.
|
|
|
|
|
|
|
|
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
|
|
|
|
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
|
|
|
|
model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
@ -224,8 +329,9 @@ def mixer_b16_224(pretrained=False, **kwargs):
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def mixer_b16_224_in21k(pretrained=False, **kwargs):
|
|
|
|
def mixer_b16_224_in21k(pretrained=False, **kwargs):
|
|
|
|
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights.
|
|
|
|
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights.
|
|
|
|
|
|
|
|
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
|
|
|
|
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
|
|
|
|
model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
@ -233,8 +339,9 @@ def mixer_b16_224_in21k(pretrained=False, **kwargs):
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def mixer_l32_224(pretrained=False, **kwargs):
|
|
|
|
def mixer_l32_224(pretrained=False, **kwargs):
|
|
|
|
""" Mixer-L/32 224x224.
|
|
|
|
""" Mixer-L/32 224x224.
|
|
|
|
|
|
|
|
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
|
|
|
|
model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, **kwargs)
|
|
|
|
model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
@ -242,8 +349,9 @@ def mixer_l32_224(pretrained=False, **kwargs):
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def mixer_l16_224(pretrained=False, **kwargs):
|
|
|
|
def mixer_l16_224(pretrained=False, **kwargs):
|
|
|
|
""" Mixer-L/16 224x224. ImageNet-1k pretrained weights.
|
|
|
|
""" Mixer-L/16 224x224. ImageNet-1k pretrained weights.
|
|
|
|
|
|
|
|
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
|
|
|
|
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs)
|
|
|
|
model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
@ -251,7 +359,101 @@ def mixer_l16_224(pretrained=False, **kwargs):
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def mixer_l16_224_in21k(pretrained=False, **kwargs):
|
|
|
|
def mixer_l16_224_in21k(pretrained=False, **kwargs):
|
|
|
|
""" Mixer-L/16 224x224. ImageNet-21k pretrained weights.
|
|
|
|
""" Mixer-L/16 224x224. ImageNet-21k pretrained weights.
|
|
|
|
|
|
|
|
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
|
|
|
|
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs)
|
|
|
|
model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def gmixer_12_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" Glu-Mixer-12 224x224 (short & fat)
|
|
|
|
|
|
|
|
Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_args = dict(
|
|
|
|
|
|
|
|
patch_size=20, num_blocks=12, hidden_dim=512, mlp_ratio=(1.0, 6.0),
|
|
|
|
|
|
|
|
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
|
|
|
|
|
|
|
|
model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def gmixer_24_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" Glu-Mixer-24 224x224 (tall & slim)
|
|
|
|
|
|
|
|
Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_args = dict(
|
|
|
|
|
|
|
|
patch_size=20, num_blocks=24, hidden_dim=384, mlp_ratio=(1.0, 6.0),
|
|
|
|
|
|
|
|
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
|
|
|
|
|
|
|
|
model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def resmlp_12_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ResMLP-12
|
|
|
|
|
|
|
|
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_args = dict(
|
|
|
|
|
|
|
|
patch_size=16, num_blocks=12, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
|
|
|
|
|
|
|
|
model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def resmlp_24_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ResMLP-24
|
|
|
|
|
|
|
|
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_args = dict(
|
|
|
|
|
|
|
|
patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
|
|
|
|
|
|
|
|
model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def resmlp_36_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ResMLP-36
|
|
|
|
|
|
|
|
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_args = dict(
|
|
|
|
|
|
|
|
patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
|
|
|
|
|
|
|
|
model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def gmlp_ti16_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" gMLP-Tiny
|
|
|
|
|
|
|
|
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_args = dict(
|
|
|
|
|
|
|
|
patch_size=16, num_blocks=30, hidden_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
|
|
|
|
|
|
|
mlp_layer=GatedMlp, **kwargs)
|
|
|
|
|
|
|
|
model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def gmlp_s16_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" gMLP-Small
|
|
|
|
|
|
|
|
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_args = dict(
|
|
|
|
|
|
|
|
patch_size=16, num_blocks=30, hidden_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
|
|
|
|
|
|
|
mlp_layer=GatedMlp, **kwargs)
|
|
|
|
|
|
|
|
model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def gmlp_b16_224(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" gMLP-Base
|
|
|
|
|
|
|
|
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_args = dict(
|
|
|
|
|
|
|
|
patch_size=16, num_blocks=30, hidden_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
|
|
|
|
|
|
|
|
mlp_layer=GatedMlp, **kwargs)
|
|
|
|
|
|
|
|
model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|