Initial MLP-Mixer attempt...

pull/613/head
Ross Wightman 4 years ago
parent 0721559511
commit 12efffa6b1

@ -14,6 +14,7 @@ from .hrnet import *
from .inception_resnet_v2 import * from .inception_resnet_v2 import *
from .inception_v3 import * from .inception_v3 import *
from .inception_v4 import * from .inception_v4 import *
from .mlp_mixer import *
from .mobilenetv3 import * from .mobilenetv3 import *
from .nasnet import * from .nasnet import *
from .nfnet import * from .nfnet import *

@ -0,0 +1,142 @@
""" MLP-Mixer in PyTorch
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
NOTE this is a very early stage first run through, the param counts aren't matching paper so
something is up...
"""
from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, lecun_normal_
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.patch_grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.patch_grid[0] * self.patch_grid[1]
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.norm(x)
return x
class MixerBlock(nn.Module):
def __init__(
self, dim, seq_len, tokens_dim, channels_dim,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
super().__init__()
self.norm1 = norm_layer(dim)
self.mlp_token = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp_channels = Mlp(dim, channels_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.mlp_token(self.norm1(x).transpose(1, 2)).transpose(1, 2))
x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
return x
class MlpMixer(nn.Module):
def __init__(
self,
num_classes=1000,
img_size=224,
in_chans=3,
patch_size=16,
num_blocks=8,
hidden_dim=512,
tokens_dim=256,
channels_dim=2048,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
drop=0.,
drop_path=0.,
):
super().__init__()
self.num_classes = num_classes
self.stem = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim)
# FIXME drop_path (stochastic depth scaling rule?)
self.blocks = nn.Sequential(*[
MixerBlock(
hidden_dim, self.stem.num_patches, tokens_dim, channels_dim,
norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path)
for _ in range(num_blocks)])
self.norm = nn.LayerNorm(hidden_dim)
self.head = nn.Linear(hidden_dim, self.num_classes) # zero init
def forward(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.norm(x)
x = x.mean(dim=1)
x = self.head(x)
return x
@register_model
def mixer_small_p16(pretrained=False, **kwargs):
model = MlpMixer()
model.default_cfg = _cfg()
return model
@register_model
def mixer_base_p16(pretrained=False, **kwargs):
model = MlpMixer(num_blocks=12, hidden_dim=768, tokens_dim=384, channels_dim=3072)
model.default_cfg = _cfg()
return model
Loading…
Cancel
Save