Move Mlp and PatchEmbed modules into layers. Being used in lots of models now...

pull/625/head
Ross Wightman 3 years ago
parent 3ba6b55cb2
commit b2c305c2aa
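This commit consolidates the duplicated Mlp and PatchEmbed definitions into timm.models.layers, so models import the shared blocks from one place instead of from vision_transformer. A minimal sketch of the new import path, assuming a timm checkout that includes this commit:

# Old: from timm.models.vision_transformer import Mlp, PatchEmbed
# New: the shared blocks sit in the layers package alongside DropPath, trunc_normal_, etc.
from timm.models.layers import Mlp, GluMlp, PatchEmbed, DropPath, trunc_normal_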

@@ -15,8 +15,7 @@ from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import trunc_normal_, DropPath
from .vision_transformer import Mlp, PatchEmbed
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from .registry import register_model

@@ -15,7 +15,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from functools import partial
@@ -54,26 +54,6 @@ default_cfgs = {
}
class Mlp(nn.Module):
""" Feed-forward network (FFN, a.k.a. MLP) class. """
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ConvRelPosEnc(nn.Module):
""" Convolutional relative position encoding. """
def __init__(self, Ch, h, window):
@@ -348,34 +328,6 @@ class ParallelBlock(nn.Module):
return x1, x2, x3, x4
class PatchEmbed(nn.Module):
""" Image to Patch Embedding """
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
f"img_size {img_size} should be divided by patch_size {patch_size}."
# Note: self.H, self.W and self.num_patches are not used
# since the image size may change on the fly.
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x):
_, _, H, W = x.shape
out_H, out_W = H // self.patch_size[0], W // self.patch_size[1]
x = self.proj(x).flatten(2).transpose(1, 2)
out = self.norm(x)
return out, (out_H, out_W)
class CoaT(nn.Module):
""" CoaT class. """
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0],
@@ -391,13 +343,17 @@ class CoaT(nn.Module):
# Patch embeddings.
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dims[0], norm_layer=nn.LayerNorm)
self.patch_embed2 = PatchEmbed(
img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1], norm_layer=nn.LayerNorm)
self.patch_embed3 = PatchEmbed(
img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2], norm_layer=nn.LayerNorm)
self.patch_embed4 = PatchEmbed(
img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3], norm_layer=nn.LayerNorm)
# Class tokens.
self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
@@ -533,7 +489,8 @@ class CoaT(nn.Module):
B = x0.shape[0]
# Serial blocks 1.
x1, (H1, W1) = self.patch_embed1(x0)
x1 = self.patch_embed1(x0)
H1, W1 = self.patch_embed1.out_size
x1 = self.insert_cls(x1, self.cls_token1)
for blk in self.serial_blocks1:
x1 = blk(x1, size=(H1, W1))
@@ -541,7 +498,8 @@ class CoaT(nn.Module):
x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
# Serial blocks 2.
x2, (H2, W2) = self.patch_embed2(x1_nocls)
x2 = self.patch_embed2(x1_nocls)
H2, W2 = self.patch_embed2.out_size
x2 = self.insert_cls(x2, self.cls_token2)
for blk in self.serial_blocks2:
x2 = blk(x2, size=(H2, W2))
@@ -549,7 +507,8 @@ class CoaT(nn.Module):
x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
# Serial blocks 3.
x3, (H3, W3) = self.patch_embed3(x2_nocls)
x3 = self.patch_embed3(x2_nocls)
H3, W3 = self.patch_embed3.out_size
x3 = self.insert_cls(x3, self.cls_token3)
for blk in self.serial_blocks3:
x3 = blk(x3, size=(H3, W3))
@@ -557,7 +516,8 @@ class CoaT(nn.Module):
x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
# Serial blocks 4.
x4, (H4, W4) = self.patch_embed4(x3_nocls)
x4 = self.patch_embed4(x3_nocls)
H4, W4 = self.patch_embed4.out_size
x4 = self.insert_cls(x4, self.cls_token4)
for blk in self.serial_blocks4:
x4 = blk(x4, size=(H4, W4))
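With the shared PatchEmbed, coat.py no longer unpacks the spatial grid from the embedding's return value; each stage reads the out_size attribute instead. A sketch of the new calling pattern for one stage (the sizes below are illustrative, not the actual CoaT configs):

# Sketch of the pattern the serial blocks now follow.
import torch
from torch import nn
from timm.models.layers import PatchEmbed

patch_embed1 = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=64, norm_layer=nn.LayerNorm)
x0 = torch.randn(2, 3, 224, 224)
x1 = patch_embed1(x0)            # (2, 3136, 64); the (H, W) tuple is no longer returned
H1, W1 = patch_embed1.out_size   # (56, 56), read as an attribute instead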

@@ -20,9 +20,11 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp
from .norm import GroupNorm
from .norm_act import BatchNormAct2d, GroupNormAct
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .se import SEModule
from .selective_kernel import SelectiveKernelConv

@@ -0,0 +1,49 @@
""" MLP module w/ dropout and configurable activation layer
Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class GluMlp(nn.Module):
""" MLP w/ GLU style gating
See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features * 2)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x, gates = x.chunk(2, dim=-1)
x = x * self.act(gates)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
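GluMlp widens fc1 to twice the hidden size, chunks the result into values and gates, and scales the values by the activated gates before fc2, so its output shape matches the plain Mlp. A quick shape check, assuming the new layers exports:

# Sketch: shape check for the shared Mlp / GluMlp layers.
import torch
from torch import nn
from timm.models.layers import Mlp, GluMlp

mlp = Mlp(in_features=256, hidden_features=1024, drop=0.1)
glu = GluMlp(in_features=256, hidden_features=1024, act_layer=nn.SiLU)  # fc1 projects to 2048, gated back to 1024

x = torch.randn(8, 196, 256)
print(mlp(x).shape)  # torch.Size([8, 196, 256])
print(glu(x).shape)  # torch.Size([8, 196, 256])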

@@ -0,0 +1,36 @@
""" Image to Patch Embedding using Conv2d
A convolution based approach to patchifying a 2D image w/ embedding projection.
Based on the impl in https://github.com/google-research/vision_transformer
Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn
from .helpers import to_2tuple
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.out_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.out_size[0] * self.out_size[1]
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.norm(x)
return x
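The shared PatchEmbed makes the norm optional (nn.Identity when norm_layer is None) and precomputes out_size and num_patches from img_size and patch_size, which the stems below rely on. A short sketch under those assumptions:

# Sketch: the shared PatchEmbed with and without a norm layer.
import torch
from torch import nn
from timm.models.layers import PatchEmbed

plain = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)  # norm is nn.Identity
normed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768, norm_layer=nn.LayerNorm)

x = torch.randn(1, 3, 224, 224)
print(plain(x).shape, plain.num_patches)  # torch.Size([1, 196, 768]) 196
print(normed(x).shape, normed.out_size)   # torch.Size([1, 196, 768]) (14, 14)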

@@ -25,7 +25,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import DropPath, to_2tuple, lecun_normal_
from .layers import PatchEmbed, Mlp, GluMlp, DropPath, lecun_normal_
from .registry import register_model
@@ -43,6 +43,7 @@ def _cfg(url='', **kwargs):
default_cfgs = dict(
mixer_s32_224=_cfg(),
mixer_s16_224=_cfg(),
mixer_s16_glu_224=_cfg(),
mixer_b32_224=_cfg(),
mixer_b16_224=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
@@ -62,65 +63,17 @@ default_cfgs = dict(
)
class Mlp(nn.Module):
""" MLP Block
NOTE: same impl as ViT, move to common location
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
NOTE: same impl as ViT, move to common location
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.patch_grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.patch_grid[0] * self.patch_grid[1]
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.norm(x)
return x
class MixerBlock(nn.Module):
def __init__(
self, dim, seq_len, tokens_dim, channels_dim,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
mlp_layer=Mlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
super().__init__()
self.norm1 = norm_layer(dim)
self.mlp_tokens = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp_channels = Mlp(dim, channels_dim, act_layer=act_layer, drop=drop)
self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
@@ -140,6 +93,7 @@ class MlpMixer(nn.Module):
hidden_dim=512,
tokens_dim=256,
channels_dim=2048,
mlp_layer=Mlp,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
drop=0.,
@@ -154,7 +108,7 @@ class MlpMixer(nn.Module):
self.blocks = nn.Sequential(*[
MixerBlock(
hidden_dim, self.stem.num_patches, tokens_dim, channels_dim,
norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path)
mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path)
for _ in range(num_blocks)])
self.norm = norm_layer(hidden_dim)
self.head = nn.Linear(hidden_dim, self.num_classes) # zero init
@@ -238,6 +192,17 @@ def mixer_s16_224(pretrained=False, **kwargs):
return model
@register_model
def mixer_s16_glu_224(pretrained=False, **kwargs):
""" Mixer-S/16 224x224
"""
model_args = dict(
patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=1536,
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
model = _create_mixer('mixer_s16_glu_224', pretrained=pretrained, **model_args)
return model
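mlp_mixer.py now threads a pluggable mlp_layer through MixerBlock, and mixer_s16_glu_224 pairs GluMlp with nn.SiLU and a narrower channels_dim. No pretrained weights are registered for it at this commit, but it can be instantiated for a smoke test, assuming this commit is installed:

# Sketch: instantiating the new GLU Mixer variant (pretrained url is empty here).
import timm
import torch

model = timm.create_model('mixer_s16_glu_224', pretrained=False)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])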
@register_model
def mixer_b32_224(pretrained=False, **kwargs):
""" Mixer-B/32 224x224

@@ -49,6 +49,9 @@ default_cfgs = {
'resnet26d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth',
interpolation='bicubic', first_conv='conv1.0'),
'resnet26t': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)),
'resnet50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth',
interpolation='bicubic'),
@@ -723,6 +726,15 @@ def resnet26(pretrained=False, **kwargs):
return _create_resnet('resnet26', pretrained, **model_args)
@register_model
def resnet26t(pretrained=False, **kwargs):
"""Constructs a ResNet-26-T model.
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
return _create_resnet('resnet26t', pretrained, **model_args)
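resnet26t is a ResNet-26 with a tiered deep stem (stem_width=32, stem_type='deep_tiered') and average-pool downsampling in the shortcuts; its default_cfg lists 256x256 inputs and an empty weight url. A hedged smoke test, assuming this commit is installed:

# Sketch: the new ResNet-26-T (no pretrained weights published at this commit).
import timm
import torch

model = timm.create_model('resnet26t', pretrained=False)
out = model(torch.randn(1, 3, 256, 256))  # default_cfg suggests 256x256 inputs
print(out.shape)  # torch.Size([1, 1000])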
@register_model
def resnet26d(pretrained=False, **kwargs):
"""Constructs a ResNet-26-D model.

@@ -22,9 +22,9 @@ import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import DropPath, to_2tuple, trunc_normal_
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from .vision_transformer import checkpoint_filter_fn, Mlp, PatchEmbed, _init_vit_weights
from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
_logger = logging.getLogger(__name__)
@@ -467,7 +467,7 @@ class SwinTransformer(nn.Module):
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
self.patch_grid = self.patch_embed.patch_grid
self.patch_grid = self.patch_embed.out_size
# absolute position embedding
if self.ape:

@@ -13,8 +13,7 @@ from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.layers import DropPath, trunc_normal_
from timm.models.vision_transformer import Mlp
from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.registry import register_model

@@ -29,7 +29,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, lecun_normal_
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from .registry import register_model
_logger = logging.getLogger(__name__)
@@ -132,25 +132,6 @@ default_cfgs = {
}
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
@@ -198,31 +179,6 @@ class Block(nn.Module):
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.patch_grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.patch_grid[0] * self.patch_grid[1]
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.norm(x)
return x
class VisionTransformer(nn.Module):
""" Vision Transformer
