# pytorch-image-models/timm/models/mlp_mixer.py
""" MLP-Mixer, ResMLP, and gMLP in PyTorch
This impl originally based on MLP-Mixer paper.
Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
@article{tolstikhin2021,
title={MLP-Mixer: An all-MLP Architecture for Vision},
author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner,
Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
journal={arXiv preprint arXiv:2105.01601},
year={2021}
}
Also supporting ResMlp, and a preliminary (not verified) implementation of gMLP
Code: https://github.com/facebookresearch/deit
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
@misc{touvron2021resmlp,
title={ResMLP: Feedforward networks for image classification with data-efficient training},
author={Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and
Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
year={2021},
eprint={2105.03404},
}
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
@misc{liu2021pay,
title={Pay Attention to MLPs},
author={Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
year={2021},
eprint={2105.08050},
}
A thank you to paper authors for releasing code and weights.
Hacked together by / Copyright 2021 Ross Wightman
"""
import math
from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._registry import register_model
# Export the model class alongside its block so star-imports of this module expose both.
__all__ = ['MixerBlock', 'MlpMixer']  # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):
    """ Build a pretrained config dict for one model variant.

    Starts from this file's defaults (1000 classes, fixed 3x224x224 input, bicubic
    interpolation, 0.875 crop, 0.5 mean/std, `stem.proj` first conv, `head` classifier)
    and layers any keyword overrides on top.

    Args:
        url: Weight location; empty string when no weights are listed.
        **kwargs: Extra or overriding config entries.

    Returns:
        A new dict with the merged configuration.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=0.875,
        interpolation='bicubic',
        fixed_input_size=True,
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        first_conv='stem.proj',
        classifier='head',
    )
    cfg.update(kwargs)  # caller-supplied entries win over the defaults
    return cfg
# Pretrained configs keyed by entrypoint name. A bare `_cfg()` lists no weight URL.
default_cfgs = {
    # MLP-Mixer
    'mixer_s32_224': _cfg(),
    'mixer_s16_224': _cfg(),
    'mixer_b32_224': _cfg(),
    'mixer_b16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
    ),
    'mixer_b16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
        num_classes=21843,
    ),
    'mixer_l32_224': _cfg(),
    'mixer_l16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
    ),
    'mixer_l16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
        num_classes=21843,
    ),

    # Mixer ImageNet-21K-P pretraining
    'mixer_b16_224_miil_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mixer_b16_224_miil_in21k-2a558a71.pth',
        mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
    ),
    'mixer_b16_224_miil': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mixer_b16_224_miil-9229a591.pth',
        mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear',
    ),

    # GLU-Mixer
    'gmixer_12_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'gmixer_24_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),

    # ResMLP
    'resmlp_12_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'resmlp_24_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
        # alternate weights, currently unused:
        # url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'resmlp_36_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'resmlp_big_24_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),

    # ResMLP, distilled weights
    'resmlp_12_distilled_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'resmlp_24_distilled_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'resmlp_36_distilled_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'resmlp_big_24_distilled_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),

    'resmlp_big_24_224_in22ft1k': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),

    # ResMLP, DINO weights
    'resmlp_12_224_dino': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dino.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),
    'resmlp_24_224_dino': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
    ),

    # gMLP
    'gmlp_ti16_224': _cfg(),
    'gmlp_s16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth',
    ),
    'gmlp_b16_224': _cfg(),
}
class MixerBlock(nn.Module):
    """ Residual block with a token-mixing MLP followed by a channel-mixing MLP.

    Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
    """

    def __init__(
            self,
            dim,
            seq_len,
            mlp_ratio=(0.5, 4.0),
            mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            drop=0.,
            drop_path=0.,
    ):
        """
        Args:
            dim: Channel width; size of the last input axis.
            seq_len: Number of tokens; size of input axis 1.
            mlp_ratio: Hidden-size multiplier(s) relative to `dim`, as a
                (token MLP, channel MLP) pair or one value used for both.
            mlp_layer: MLP class, built as `mlp_layer(in, hidden, act_layer=..., drop=...)`.
            norm_layer: Constructor for the two pre-MLP norms.
            act_layer: Activation class handed to both MLPs.
            drop: Dropout rate handed to both MLPs.
            drop_path: Drop-path rate on both residual branches; 0 disables it.
        """
        super().__init__()
        token_ratio, channel_ratio = to_2tuple(mlp_ratio)
        tokens_dim = int(token_ratio * dim)
        channels_dim = int(channel_ratio * dim)

        self.norm1 = norm_layer(dim)
        self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        # Token mixing: swap axes 1 and 2 so the MLP runs along the token axis, then swap back.
        mixed = self.norm1(x).transpose(1, 2)
        mixed = self.mlp_tokens(mixed).transpose(1, 2)
        x = x + self.drop_path(mixed)
        # Channel mixing acts on the last axis directly.
        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
        return x
class Affine(nn.Module):
    """ Learnable per-channel scale and shift, computing `beta + alpha * x`.

    `alpha` starts at ones and `beta` at zeros, both shaped (1, 1, dim).
    """

    def __init__(self, dim):
        super().__init__()
        shape = (1, 1, dim)
        self.alpha = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))

    def forward(self, x):
        # Single addcmul op: beta + alpha * x.
        return self.beta.addcmul(self.alpha, x)
class ResBlock(nn.Module):
    """ Residual MLP block with a linear token mixer, LayerScale and Affine 'norm'.

    Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """

    def __init__(
            self,
            dim,
            seq_len,
            mlp_ratio=4,
            mlp_layer=Mlp,
            norm_layer=Affine,
            act_layer=nn.GELU,
            init_values=1e-4,
            drop=0.,
            drop_path=0.,
    ):
        """
        Args:
            dim: Channel width; size of the last input axis.
            seq_len: Number of tokens; size of input axis 1.
            mlp_ratio: Hidden-size multiplier of the channel MLP relative to `dim`.
            mlp_layer: MLP class used for channel mixing.
            norm_layer: Constructor for the two pre-branch norms.
            act_layer: Activation class handed to the channel MLP.
            init_values: Initial value of both per-channel LayerScale vectors.
            drop: Dropout rate handed to the channel MLP.
            drop_path: Drop-path rate on both residual branches; 0 disables it.
        """
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.linear_tokens = nn.Linear(seq_len, seq_len)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp_channels = mlp_layer(dim, hidden_dim, act_layer=act_layer, drop=drop)
        # LayerScale: one learnable scale per channel for each residual branch.
        self.ls1 = nn.Parameter(init_values * torch.ones(dim))
        self.ls2 = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        # Token branch: the linear layer runs along the token axis via transposes.
        token_out = self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        x = x + self.drop_path(self.ls1 * token_out)
        # Channel branch.
        channel_out = self.mlp_channels(self.norm2(x))
        x = x + self.drop_path(self.ls2 * channel_out)
        return x
class SpatialGatingUnit(nn.Module):
    """ Spatial Gating Unit

    Splits the last axis in half; the second half is normalised, projected along the
    token axis and used to multiplicatively gate the first half.

    Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
    """

    def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm):
        super().__init__()
        # The gate operates on half of the incoming channels.
        self.norm = norm_layer(dim // 2)
        self.proj = nn.Linear(seq_len, seq_len)

    def init_weights(self):
        # Special init for the projection gate, invoked as an override by the base model init:
        # near-zero weight and unit bias make the gate start out close to 1.
        nn.init.normal_(self.proj.weight, std=1e-6)
        nn.init.ones_(self.proj.bias)

    def forward(self, x):
        passthrough, gate = x.chunk(2, dim=-1)
        # Project along the token axis: move it last, apply the linear, move it back.
        gate = self.proj(self.norm(gate).transpose(-1, -2)).transpose(-1, -2)
        return passthrough * gate
class SpatialGatingBlock(nn.Module):
    """ Residual block whose single branch is a gated MLP using a Spatial Gating Unit.

    Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
    """

    def __init__(
            self,
            dim,
            seq_len,
            mlp_ratio=4,
            mlp_layer=GatedMlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            drop=0.,
            drop_path=0.,
    ):
        """
        Args:
            dim: Channel width; size of the last input axis.
            seq_len: Number of tokens, bound into the gating unit.
            mlp_ratio: Hidden-size multiplier of the gated MLP relative to `dim`.
            mlp_layer: Gated MLP class; must accept a `gate_layer` argument.
            norm_layer: Constructor for the pre-MLP norm.
            act_layer: Activation class handed to the MLP.
            drop: Dropout rate handed to the MLP.
            drop_path: Drop-path rate on the residual branch; 0 disables it.
        """
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.norm = norm_layer(dim)
        # Bind seq_len now; the MLP supplies the remaining gate arguments itself.
        gate_factory = partial(SpatialGatingUnit, seq_len=seq_len)
        self.mlp_channels = mlp_layer(
            dim, hidden_dim, act_layer=act_layer, gate_layer=gate_factory, drop=drop)
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        branch = self.mlp_channels(self.norm(x))
        return x + self.drop_path(branch)
class MlpMixer(nn.Module):
    """ Patch-embedding stem, a stack of mixing blocks, a final norm and a linear classifier head.

    The block type is pluggable via `block_layer`, which is how the Mixer, ResMLP and gMLP
    entrypoints in this file share one model class.

    Args:
        num_classes: Classifier output size; a value <= 0 replaces the head with `nn.Identity`.
        img_size: Input image size handed to the patch-embedding stem.
        in_chans: Number of input image channels.
        patch_size: Patch size handed to the patch-embedding stem.
        num_blocks: Number of mixing blocks.
        embed_dim: Channel width of the token embeddings.
        mlp_ratio: Hidden-size ratio(s) forwarded to every block.
        block_layer: Callable building one block; receives `(embed_dim, num_patches, mlp_ratio)`
            plus `mlp_layer`, `norm_layer`, `act_layer`, `drop` and `drop_path` keywords.
        mlp_layer: MLP class forwarded to every block.
        norm_layer: Norm constructor used in the blocks, before the head, and in the stem
            when `stem_norm` is set.
        act_layer: Activation class forwarded to every block.
        drop_rate: Dropout rate forwarded to every block as `drop`.
        drop_path_rate: Drop-path rate forwarded unchanged to every block.
        nlhb: If True, initialise the head bias to `-log(num_classes)`.
        stem_norm: If True, the stem uses `norm_layer`; otherwise it has no norm.
        global_pool: `'avg'` averages over dim 1 before the head; any other value skips pooling.
    """

    def __init__(
            self,
            num_classes=1000,
            img_size=224,
            in_chans=3,
            patch_size=16,
            num_blocks=8,
            embed_dim=512,
            mlp_ratio=(0.5, 4.0),
            block_layer=MixerBlock,
            mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            drop_rate=0.,
            drop_path_rate=0.,
            nlhb=False,
            stem_norm=False,
            global_pool='avg',
    ):
        super().__init__()
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.grad_checkpointing = False

        self.stem = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None)
        # FIXME drop_path (stochastic depth scaling rule or all the same?)
        self.blocks = nn.Sequential(*[
            block_layer(
                embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
                act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
            for _ in range(num_blocks)])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(nlhb=nlhb)

    @torch.jit.ignore
    def init_weights(self, nlhb=False):
        """ Apply `_init_weights` to every submodule.

        Args:
            nlhb: If True, use `-log(num_classes)` as the classifier bias instead of 0.
        """
        # log(0) is undefined, and a headless model (num_classes <= 0) has no classifier
        # bias to set, so the negative-log bias is only computed for a real classifier.
        use_nlhb = nlhb and self.num_classes > 0
        head_bias = -math.log(self.num_classes) if use_nlhb else 0.
        named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """ Return regex patterns grouping parameters into stem and per-block groups.

        `coarse` is accepted but not used here.
        """
        return dict(
            stem=r'^stem',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """ Turn gradient checkpointing of the block stack on or off. """
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        """ Return the classifier head module. """
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        """ Replace the head with a fresh one for `num_classes` outputs.

        Args:
            num_classes: New output size; a value <= 0 installs `nn.Identity`.
            global_pool: Optional new pooling mode, `''` or `'avg'`; left unchanged when None.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg')
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """ Run stem, blocks and final norm; returns the un-pooled token features. """
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        """ Compute features, optionally average over dim 1, then apply the head. """
        x = self.forward_features(x)
        if self.global_pool == 'avg':
            x = x.mean(dim=1)
        x = self.head(x)
        return x
def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
    """ Mixer weight initialization (trying to match Flax defaults)

    Intended to be applied per module together with its qualified name.

    Args:
        module: Module to initialise in place.
        name: Qualified name of `module`; a `head` prefix selects the classifier init and
            an `mlp` substring selects the small-normal bias init.
        head_bias: Constant written to the classifier bias.
        flax: If True, non-head linear layers use lecun-normal weights and zero bias;
            otherwise xavier-uniform weights.

    Parameters that are absent (`None`), such as a missing bias or a norm layer built
    without affine parameters, are skipped rather than raising.
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, head_bias)
        elif flax:
            # Flax defaults
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        else:
            # like MLP init in vit (my original init)
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                if 'mlp' in name:
                    nn.init.normal_(module.bias, std=1e-6)
                else:
                    nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
        # Norm layers created without affine parameters expose weight/bias as None.
        if module.weight is not None:
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        # NOTE if a parent module contains init_weights method, it can override the init of the
        # child modules as this will be called in depth-first order.
        module.init_weights()
def checkpoint_filter_fn(state_dict, model):
    """ Remap checkpoints if needed

    A state dict containing `patch_embed.proj.weight` is treated as an FB ResMLP checkpoint
    and returned as a new dict with keys renamed to this module's naming; `.alpha` / `.beta`
    values are reshaped to (1, 1, -1). Any other state dict is returned unchanged.
    `model` is not used.
    """
    if 'patch_embed.proj.weight' not in state_dict:
        return state_dict

    # Remap FB ResMlp models -> timm; substitutions are applied in this order.
    renames = (
        ('patch_embed.', 'stem.'),
        ('attn.', 'linear_tokens.'),
        ('mlp.', 'mlp_channels.'),
        ('gamma_', 'ls'),
    )
    remapped = {}
    for key, value in state_dict.items():
        for old, new in renames:
            key = key.replace(old, new)
        if key.endswith(('.alpha', '.beta')):
            value = value.reshape(1, 1, -1)
        remapped[key] = value
    return remapped
def _create_mixer(variant, pretrained=False, **kwargs):
    """ Build an `MlpMixer` for `variant` through `build_model_with_cfg`.

    Args:
        variant: Model variant name passed on to the builder.
        pretrained: Passed on to the builder.
        **kwargs: Passed on to the builder.

    Raises:
        RuntimeError: If a truthy `features_only` is requested.
    """
    if kwargs.get('features_only'):
        raise RuntimeError('features_only not implemented for MLP-Mixer models.')
    return build_model_with_cfg(
        MlpMixer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )
@register_model
def mixer_s32_224(pretrained=False, **kwargs):
    """ Mixer-S/32 224x224 (patch size 32, 8 blocks, embed dim 512).

    Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
    """
    return _create_mixer(
        'mixer_s32_224', pretrained=pretrained,
        patch_size=32, num_blocks=8, embed_dim=512, **kwargs)
@register_model
def mixer_s16_224(pretrained=False, **kwargs):
    """ Mixer-S/16 224x224 (patch size 16, 8 blocks, embed dim 512).

    Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
    """
    return _create_mixer(
        'mixer_s16_224', pretrained=pretrained,
        patch_size=16, num_blocks=8, embed_dim=512, **kwargs)
@register_model
def mixer_b32_224(pretrained=False, **kwargs):
    """ Mixer-B/32 224x224 (patch size 32, 12 blocks, embed dim 768).

    Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
    """
    return _create_mixer(
        'mixer_b32_224', pretrained=pretrained,
        patch_size=32, num_blocks=12, embed_dim=768, **kwargs)
@register_model
def mixer_b16_224(pretrained=False, **kwargs):
    """ Mixer-B/16 224x224 (patch size 16, 12 blocks, embed dim 768). ImageNet-1k pretrained weights.

    Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
    """
    return _create_mixer(
        'mixer_b16_224', pretrained=pretrained,
        patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
@register_model
def mixer_b16_224_in21k(pretrained=False, **kwargs):
    """ Mixer-B/16 224x224 (patch size 16, 12 blocks, embed dim 768). ImageNet-21k pretrained weights.

    Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
    """
    return _create_mixer(
        'mixer_b16_224_in21k', pretrained=pretrained,
        patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
@register_model
def mixer_l32_224(pretrained=False, **kwargs):
    """ Mixer-L/32 224x224 (patch size 32, 24 blocks, embed dim 1024).

    Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
    """
    return _create_mixer(
        'mixer_l32_224', pretrained=pretrained,
        patch_size=32, num_blocks=24, embed_dim=1024, **kwargs)
@register_model
def mixer_l16_224(pretrained=False, **kwargs):
    """ Mixer-L/16 224x224 (patch size 16, 24 blocks, embed dim 1024). ImageNet-1k pretrained weights.

    Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
    """
    return _create_mixer(
        'mixer_l16_224', pretrained=pretrained,
        patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
@register_model
def mixer_l16_224_in21k(pretrained=False, **kwargs):
    """ Mixer-L/16 224x224 (patch size 16, 24 blocks, embed dim 1024). ImageNet-21k pretrained weights.

    Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
    """
    return _create_mixer(
        'mixer_l16_224_in21k', pretrained=pretrained,
        patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
@register_model
def mixer_b16_224_miil(pretrained=False, **kwargs):
    """ Mixer-B/16 224x224 (patch size 16, 12 blocks, embed dim 768).

    ImageNet-1k weights (this variant's config keeps the 1000-class default) from the
    MIIL ImageNet-21K-P release; presumably fine-tuned from that pretraining.
    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
    """
    model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
    model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args)
    return model
@register_model
def mixer_b16_224_miil_in21k(pretrained=False, **kwargs):
    """ Mixer-B/16 224x224 (patch size 16, 12 blocks, embed dim 768).

    ImageNet-21K-P pretrained weights (this variant's config sets an 11221-class head).
    Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
    """
    model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
    model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args)
    return model
@register_model
def gmixer_12_224(pretrained=False, **kwargs):
    """ Glu-Mixer-12 224x224 (patch size 16, 12 blocks, embed dim 384, GluMlp + SiLU).

    Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
    """
    return _create_mixer(
        'gmixer_12_224', pretrained=pretrained,
        patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=(1.0, 4.0),
        mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
@register_model
def gmixer_24_224(pretrained=False, **kwargs):
    """ Glu-Mixer-24 224x224 (patch size 16, 24 blocks, embed dim 384, GluMlp + SiLU).

    Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
    """
    return _create_mixer(
        'gmixer_24_224', pretrained=pretrained,
        patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 4.0),
        mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
@register_model
def resmlp_12_224(pretrained=False, **kwargs):
    """ ResMLP-12 (patch size 16, 12 blocks, embed dim 384, default LayerScale init).

    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    return _create_mixer(
        'resmlp_12_224', pretrained=pretrained,
        patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4,
        block_layer=ResBlock, norm_layer=Affine, **kwargs)
@register_model
def resmlp_24_224(pretrained=False, **kwargs):
    """ ResMLP-24 (patch size 16, 24 blocks, embed dim 384, LayerScale init 1e-5).

    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    return _create_mixer(
        'resmlp_24_224', pretrained=pretrained,
        patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
@register_model
def resmlp_36_224(pretrained=False, **kwargs):
    """ ResMLP-36 (patch size 16, 36 blocks, embed dim 384, LayerScale init 1e-6).

    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    return _create_mixer(
        'resmlp_36_224', pretrained=pretrained,
        patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
@register_model
def resmlp_big_24_224(pretrained=False, **kwargs):
    """ ResMLP-B-24 (patch size 8, 24 blocks, embed dim 768, LayerScale init 1e-6).

    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    return _create_mixer(
        'resmlp_big_24_224', pretrained=pretrained,
        patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
@register_model
def resmlp_12_distilled_224(pretrained=False, **kwargs):
    """ ResMLP-12, distilled variant (patch size 16, 12 blocks, embed dim 384, default LayerScale init).

    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    return _create_mixer(
        'resmlp_12_distilled_224', pretrained=pretrained,
        patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4,
        block_layer=ResBlock, norm_layer=Affine, **kwargs)
@register_model
def resmlp_24_distilled_224(pretrained=False, **kwargs):
    """ ResMLP-24, distilled variant (patch size 16, 24 blocks, embed dim 384, LayerScale init 1e-5).

    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    return _create_mixer(
        'resmlp_24_distilled_224', pretrained=pretrained,
        patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
@register_model
def resmlp_36_distilled_224(pretrained=False, **kwargs):
    """ ResMLP-36, distilled variant (patch size 16, 36 blocks, embed dim 384, LayerScale init 1e-6).

    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    return _create_mixer(
        'resmlp_36_distilled_224', pretrained=pretrained,
        patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
@register_model
def resmlp_big_24_distilled_224(pretrained=False, **kwargs):
    """ ResMLP-B-24, distilled variant (patch size 8, 24 blocks, embed dim 768, LayerScale init 1e-6).

    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    return _create_mixer(
        'resmlp_big_24_distilled_224', pretrained=pretrained,
        patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
@register_model
def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs):
    """ ResMLP-B-24, `in22ft1k` variant (patch size 8, 24 blocks, embed dim 768, LayerScale init 1e-6).

    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    return _create_mixer(
        'resmlp_big_24_224_in22ft1k', pretrained=pretrained,
        patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
@register_model
def resmlp_12_224_dino(pretrained=False, **kwargs):
    """ ResMLP-12 (patch size 16, 12 blocks, embed dim 384, default LayerScale init).

    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    Model pretrained via DINO (self-supervised) - https://arxiv.org/abs/2104.14294
    """
    return _create_mixer(
        'resmlp_12_224_dino', pretrained=pretrained,
        patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4,
        block_layer=ResBlock, norm_layer=Affine, **kwargs)
@register_model
def resmlp_24_224_dino(pretrained=False, **kwargs):
    """ ResMLP-24 (patch size 16, 24 blocks, embed dim 384, LayerScale init 1e-5).

    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    Model pretrained via DINO (self-supervised) - https://arxiv.org/abs/2104.14294
    """
    return _create_mixer(
        'resmlp_24_224_dino', pretrained=pretrained,
        patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
@register_model
def gmlp_ti16_224(pretrained=False, **kwargs):
    """ gMLP-Tiny (patch size 16, 30 blocks, embed dim 128, mlp ratio 6).

    Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
    """
    return _create_mixer(
        'gmlp_ti16_224', pretrained=pretrained,
        patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6,
        block_layer=SpatialGatingBlock, mlp_layer=GatedMlp, **kwargs)
@register_model
def gmlp_s16_224(pretrained=False, **kwargs):
    """ gMLP-Small (patch size 16, 30 blocks, embed dim 256, mlp ratio 6).

    Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
    """
    return _create_mixer(
        'gmlp_s16_224', pretrained=pretrained,
        patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6,
        block_layer=SpatialGatingBlock, mlp_layer=GatedMlp, **kwargs)
@register_model
def gmlp_b16_224(pretrained=False, **kwargs):
    """ gMLP-Base (patch size 16, 30 blocks, embed dim 512, mlp ratio 6).

    Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
    """
    return _create_mixer(
        'gmlp_b16_224', pretrained=pretrained,
        patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6,
        block_layer=SpatialGatingBlock, mlp_layer=GatedMlp, **kwargs)