@@ -14,8 +14,9 @@ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2
  year={2021}
}

Also supporting preliminary (not verified) implementations of ResMlp, gMLP, and possibly more...
Also supporting ResMlp, and a preliminary (not verified) implementation of gMLP

Code: https://github.com/facebookresearch/deit
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
@misc{touvron2021resmlp,
      title={ResMLP: Feedforward networks for image classification with data-efficient training},
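Since this file registers every variant below with timm, a minimal usage sketch (assuming a recent timm that includes this module, plus network access for the pretrained weights):

import torch
import timm

# Any name registered in this file works here, e.g. 'mixer_b16_224',
# 'resmlp_24_224', 'resmlp_24_distilled_224', 'gmlp_s16_224'.
model = timm.create_model('resmlp_24_224', pretrained=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # ImageNet-1k logits, shape (1, 1000)
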
@@ -94,11 +95,36 @@ default_cfgs = dict(
    gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    gmixer_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    resmlp_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    resmlp_12_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    resmlp_24_224=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=0.89),
    resmlp_36_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
        #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    resmlp_36_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    resmlp_big_24_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    resmlp_12_distilled_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    resmlp_24_distilled_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    resmlp_36_distilled_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    resmlp_big_24_distilled_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    resmlp_big_24_224_in22ft1k=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),

    gmlp_ti16_224=_cfg(),
    gmlp_s16_224=_cfg(),
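Each _cfg() entry becomes the model's default_cfg, which is how the pretrained URL and the ImageNet mean/std requested above reach the data pipeline. A small sketch of inspecting one of the new entries (the exact set of keys follows the usual timm default_cfg layout and should be treated as an assumption):

import timm

m = timm.create_model('resmlp_36_224', pretrained=False)
cfg = m.default_cfg  # dict built from the _cfg(...) entry above
# Expected keys include 'url', 'mean', 'std', 'input_size', 'crop_pct', etc.
print(cfg.get('url'), cfg.get('mean'), cfg.get('std'))
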
@@ -266,7 +292,7 @@ class MlpMixer(nn.Module):
        return x


def _init_weights(module: nn.Module, name: str, head_bias: float = 0.):
def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
    """ Mixer weight initialization (trying to match Flax defaults)
    """
    if isinstance(module, nn.Linear):
@@ -274,6 +300,13 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0.):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        else:
            if flax:
                # Flax defaults
                lecun_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            else:
                # like MLP init in vit (my original init)
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    if 'mlp' in name:
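The new flax switch only changes the non-head Linear init: LeCun normal (the Flax default) versus the original Xavier uniform. A toy illustration of the difference on a standalone layer, using plain PyTorch rather than the timm helpers themselves:

import math
import torch.nn as nn

lin = nn.Linear(384, 1536)

# Default branch: Xavier/Glorot uniform, variance scaled by fan_in + fan_out.
nn.init.xavier_uniform_(lin.weight)

# flax=True branch: LeCun normal, std = sqrt(1 / fan_in); timm's lecun_normal_
# additionally truncates the distribution, which is omitted here for brevity.
nn.init.normal_(lin.weight, std=math.sqrt(1.0 / lin.in_features))
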
@@ -293,6 +326,23 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0.):
        module.init_weights()


def checkpoint_filter_fn(state_dict, model):
    """ Remap checkpoints if needed """
    if 'patch_embed.proj.weight' in state_dict:
        # Remap FB ResMlp models -> timm
        out_dict = {}
        for k, v in state_dict.items():
            k = k.replace('patch_embed.', 'stem.')
            k = k.replace('attn.', 'linear_tokens.')
            k = k.replace('mlp.', 'mlp_channels.')
            k = k.replace('gamma_', 'ls')
            if k.endswith('.alpha') or k.endswith('.beta'):
                v = v.reshape(1, 1, -1)
            out_dict[k] = v
        return out_dict
    return state_dict
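A quick sanity-check sketch of the remap on made-up keys (the FB-style names below are illustrative, not copied from an actual checkpoint):

import torch
from timm.models.mlp_mixer import checkpoint_filter_fn

fb_style = {
    'patch_embed.proj.weight': torch.zeros(384, 3, 16, 16),  # triggers the remap branch -> 'stem.proj.weight'
    'blocks.0.attn.weight': torch.zeros(196, 196),            # -> 'blocks.0.linear_tokens.weight'
    'blocks.0.norm1.alpha': torch.zeros(384),                  # -> reshaped to (1, 1, 384)
    'blocks.0.gamma_1': torch.zeros(384),                      # -> 'blocks.0.ls1'
}
remapped = checkpoint_filter_fn(fb_style, model=None)  # model is unused on this path
print(sorted(remapped.keys()))
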

def _create_mixer(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for MLP-Mixer models.')
@@ -300,6 +350,7 @@ def _create_mixer(variant, pretrained=False, **kwargs):
    model = build_model_with_cfg(
        MlpMixer, variant, pretrained,
        default_cfg=default_cfgs[variant],
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)
    return model
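With pretrained_filter_fn=checkpoint_filter_fn wired in here, loading the FB-released weights needs no extra steps; a minimal sketch (assuming network access for the download):

import timm

# build_model_with_cfg fetches default_cfg['url'] and passes the raw FB state
# dict through checkpoint_filter_fn before load_state_dict.
model = timm.create_model('resmlp_big_24_distilled_224', pretrained=True)
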
@@ -458,11 +509,82 @@ def resmlp_36_224(pretrained=False, **kwargs):
    """
    model_args = dict(
        patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
    model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
    return model

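For context on the init_values change (1e-5 -> 1e-6): it sets the starting value of the learnable per-channel scale applied to each residual branch in ResBlock, in the spirit of LayerScale, and deeper ResMLP variants use smaller values. A rough standalone sketch of that idea, not the timm implementation itself:

import torch
import torch.nn as nn

class ScaledResidual(nn.Module):
    """Residual branch with a learnable per-channel scale, initialised small."""
    def __init__(self, dim, fn, init_values=1e-6):
        super().__init__()
        self.fn = fn
        self.scale = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x + self.scale * self.fn(x)

block = ScaledResidual(384, nn.Linear(384, 384))  # toy usage
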
@register_model
def resmlp_big_24_224(pretrained=False, **kwargs):
    """ ResMLP-B-24
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    model_args = dict(
        patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
    model = _create_mixer('resmlp_big_24_224', pretrained=pretrained, **model_args)
    return model

@register_model
def resmlp_12_distilled_224(pretrained=False, **kwargs):
    """ ResMLP-12
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    model_args = dict(
        patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
    model = _create_mixer('resmlp_12_distilled_224', pretrained=pretrained, **model_args)
    return model


@register_model
def resmlp_24_distilled_224(pretrained=False, **kwargs):
    """ ResMLP-24
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    model_args = dict(
        patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
    model = _create_mixer('resmlp_24_distilled_224', pretrained=pretrained, **model_args)
    return model


@register_model
def resmlp_36_distilled_224(pretrained=False, **kwargs):
    """ ResMLP-36
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    model_args = dict(
        patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
    model = _create_mixer('resmlp_36_distilled_224', pretrained=pretrained, **model_args)
    return model


@register_model
def resmlp_big_24_distilled_224(pretrained=False, **kwargs):
    """ ResMLP-B-24
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    model_args = dict(
        patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
    model = _create_mixer('resmlp_big_24_distilled_224', pretrained=pretrained, **model_args)
    return model

@register_model
def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs):
    """ ResMLP-B-24
    Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
    """
    model_args = dict(
        patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
    model = _create_mixer('resmlp_big_24_224_in22ft1k', pretrained=pretrained, **model_args)
    return model

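The same pattern extends to new variants; a hypothetical sketch of adding one outside this file (the name, depth, and width are made up, and it assumes the timm version this diff targets, where default_cfgs is a plain dict that _create_mixer indexes):

from functools import partial

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.registry import register_model
from timm.models.mlp_mixer import Affine, ResBlock, _cfg, _create_mixer, default_cfgs

# Hypothetical 6-block toy variant; no pretrained weights, so the cfg has no url.
default_cfgs['resmlp_6_example_224'] = _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)

@register_model
def resmlp_6_example_224(pretrained=False, **kwargs):
    model_args = dict(
        patch_size=16, num_blocks=6, embed_dim=384, mlp_ratio=4,
        block_layer=partial(ResBlock, init_values=0.1), norm_layer=Affine, **kwargs)
    return _create_mixer('resmlp_6_example_224', pretrained=pretrained, **model_args)
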
@register_model
def gmlp_ti16_224(pretrained=False, **kwargs):
    """ gMLP-Tiny