Add official pretrained weights to MLP-Mixer, complete model cfgs.

pull/613/head
Ross Wightman 3 years ago
parent 12efffa6b1
commit 2d8b09fe8b

@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
torch._C._jit_set_profiling_mode(False)
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*']
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'mixer_*']
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures

@ -1,10 +1,23 @@
""" MLP-Mixer in PyTorch
Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
NOTE this is a very early stage first run through, the param counts aren't matching paper so
something is up...
@article{tolstikhin2021,
title={MLP-Mixer: An all-MLP Architecture for Vision},
author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner,
Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
journal={arXiv preprint arXiv:2105.01601},
year={2021}
}
A thank you to paper authors for releasing code and weights.
Hacked together by / Copyright 2021 Ross Wightman
"""
import math
from copy import deepcopy
from functools import partial
import torch
@ -12,7 +25,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, lecun_normal_
from .layers import DropPath, to_2tuple, lecun_normal_
from .registry import register_model
@ -20,14 +33,39 @@ def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'stem.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = dict(
mixer_s32_224=_cfg(),
mixer_s16_224=_cfg(),
mixer_b32_224=_cfg(),
mixer_b16_224=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
),
mixer_b16_224_in21k=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
num_classes=21843
),
mixer_l32_224=_cfg(),
mixer_l16_224=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
),
mixer_l16_224_in21k=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
num_classes=21843
),
)
class Mlp(nn.Module):
""" MLP Block
NOTE: same impl as ViT, move to common location
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
@ -48,6 +86,7 @@ class Mlp(nn.Module):
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
NOTE: same impl as ViT, move to common location
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
super().__init__()
@ -78,13 +117,13 @@ class MixerBlock(nn.Module):
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
super().__init__()
self.norm1 = norm_layer(dim)
self.mlp_token = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
self.mlp_tokens = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp_channels = Mlp(dim, channels_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.mlp_token(self.norm1(x).transpose(1, 2)).transpose(1, 2))
x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
return x
@ -105,6 +144,7 @@ class MlpMixer(nn.Module):
act_layer=nn.GELU,
drop=0.,
drop_path=0.,
nlhb=False,
):
super().__init__()
self.num_classes = num_classes
@ -116,9 +156,16 @@ class MlpMixer(nn.Module):
hidden_dim, self.stem.num_patches, tokens_dim, channels_dim,
norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path)
for _ in range(num_blocks)])
self.norm = nn.LayerNorm(hidden_dim)
self.norm = norm_layer(hidden_dim)
self.head = nn.Linear(hidden_dim, self.num_classes) # zero init
self.init_weights(nlhb=nlhb)
def init_weights(self, nlhb=False):
head_bias = -math.log(self.num_classes) if nlhb else 0.
for n, m in self.named_modules():
_init_weights(m, n, head_bias=head_bias)
def forward(self, x):
x = self.stem(x)
x = self.blocks(x)
@ -128,15 +175,118 @@ class MlpMixer(nn.Module):
return x
def _init_weights(m, n: str, head_bias: float = 0.):
""" Mixer weight initialization (trying to match Flax defaults)
"""
if isinstance(m, nn.Linear):
if n.startswith('head'):
nn.init.zeros_(m.weight)
nn.init.constant_(m.bias, head_bias)
else:
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
if 'mlp' in n:
nn.init.normal_(m.bias, std=1e-6)
else:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
lecun_normal_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.zeros_(m.bias)
nn.init.ones_(m.weight)
def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for MLP-Mixer models.')
model = build_model_with_cfg(
MlpMixer, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
**kwargs)
return model
@register_model
def mixer_s32_224(pretrained=False, **kwargs):
""" Mixer-S/32 224x224
"""
model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
return model
@register_model
def mixer_s16_224(pretrained=False, **kwargs):
""" Mixer-S/16 224x224
"""
model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=2048, **kwargs)
model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
return model
@register_model
def mixer_small_p16(pretrained=False, **kwargs):
model = MlpMixer()
model.default_cfg = _cfg()
def mixer_b32_224(pretrained=False, **kwargs):
""" Mixer-B/32 224x224
"""
model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
return model
@register_model
def mixer_base_p16(pretrained=False, **kwargs):
model = MlpMixer(num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072)
model.default_cfg = _cfg()
def mixer_b16_224(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights.
"""
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
return model
@register_model
def mixer_b16_224_in21k(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights.
"""
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072, **kwargs)
model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
return model
@register_model
def mixer_l32_224(pretrained=False, **kwargs):
""" Mixer-L/32 224x224.
"""
model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
return model
@register_model
def mixer_l16_224(pretrained=False, **kwargs):
""" Mixer-L/16 224x224. ImageNet-1k pretrained weights.
"""
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
return model
@register_model
def mixer_l16_224_in21k(pretrained=False, **kwargs):
""" Mixer-L/16 224x224. ImageNet-21k pretrained weights.
"""
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, tokens_dim=512, channels_dim=4096, **kwargs)
model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
return model
Loading…
Cancel
Save