From b2c305c2aa090ebc44f5836737dc3e28b413d43b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 6 May 2021 14:03:23 -0700 Subject: [PATCH] Move Mlp and PatchEmbed modules into layers. Being used in lots of models now... --- timm/models/cait.py | 3 +- timm/models/coat.py | 74 +++++++------------------------ timm/models/layers/__init__.py | 2 + timm/models/layers/mlp.py | 49 ++++++++++++++++++++ timm/models/layers/patch_embed.py | 36 +++++++++++++++ timm/models/mlp_mixer.py | 71 ++++++++--------------------- timm/models/resnet.py | 12 +++++ timm/models/swin_transformer.py | 6 +-- timm/models/tnt.py | 3 +- timm/models/vision_transformer.py | 46 +------------------ 10 files changed, 140 insertions(+), 162 deletions(-) create mode 100644 timm/models/layers/mlp.py create mode 100644 timm/models/layers/patch_embed.py diff --git a/timm/models/cait.py b/timm/models/cait.py index c16bf86a..b648e712 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -15,8 +15,7 @@ from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .layers import trunc_normal_, DropPath -from .vision_transformer import Mlp, PatchEmbed +from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_ from .registry import register_model diff --git a/timm/models/coat.py b/timm/models/coat.py index 7b364dae..38bc93a3 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -15,7 +15,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import load_pretrained -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from timm.models.registry import register_model from functools import partial @@ -54,26 +54,6 @@ default_cfgs = { } -class Mlp(nn.Module): - """ Feed-forward network (FFN, a.k.a. MLP) class. """ - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - class ConvRelPosEnc(nn.Module): """ Convolutional relative position encoding. """ def __init__(self, Ch, h, window): @@ -348,34 +328,6 @@ class ParallelBlock(nn.Module): return x1, x2, x3, x4 -class PatchEmbed(nn.Module): - """ Image to Patch Embedding """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - - self.img_size = img_size - self.patch_size = patch_size - assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ - f"img_size {img_size} should be divided by patch_size {patch_size}." - # Note: self.H, self.W and self.num_patches are not used - # since the image size may change on the fly. 
- self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] - self.num_patches = self.H * self.W - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - self.norm = nn.LayerNorm(embed_dim) - - def forward(self, x): - _, _, H, W = x.shape - out_H, out_W = H // self.patch_size[0], W // self.patch_size[1] - - x = self.proj(x).flatten(2).transpose(1, 2) - out = self.norm(x) - - return out, (out_H, out_W) - - class CoaT(nn.Module): """ CoaT class. """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0], @@ -391,13 +343,17 @@ class CoaT(nn.Module): # Patch embeddings. self.patch_embed1 = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dims[0], norm_layer=nn.LayerNorm) self.patch_embed2 = PatchEmbed( - img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]) + img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], + embed_dim=embed_dims[1], norm_layer=nn.LayerNorm) self.patch_embed3 = PatchEmbed( - img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]) + img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], + embed_dim=embed_dims[2], norm_layer=nn.LayerNorm) self.patch_embed4 = PatchEmbed( - img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]) + img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], + embed_dim=embed_dims[3], norm_layer=nn.LayerNorm) # Class tokens. self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0])) @@ -533,7 +489,8 @@ class CoaT(nn.Module): B = x0.shape[0] # Serial blocks 1. - x1, (H1, W1) = self.patch_embed1(x0) + x1 = self.patch_embed1(x0) + H1, W1 = self.patch_embed1.out_size x1 = self.insert_cls(x1, self.cls_token1) for blk in self.serial_blocks1: x1 = blk(x1, size=(H1, W1)) @@ -541,7 +498,8 @@ class CoaT(nn.Module): x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() # Serial blocks 2. - x2, (H2, W2) = self.patch_embed2(x1_nocls) + x2 = self.patch_embed2(x1_nocls) + H2, W2 = self.patch_embed2.out_size x2 = self.insert_cls(x2, self.cls_token2) for blk in self.serial_blocks2: x2 = blk(x2, size=(H2, W2)) @@ -549,7 +507,8 @@ class CoaT(nn.Module): x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous() # Serial blocks 3. - x3, (H3, W3) = self.patch_embed3(x2_nocls) + x3 = self.patch_embed3(x2_nocls) + H3, W3 = self.patch_embed3.out_size x3 = self.insert_cls(x3, self.cls_token3) for blk in self.serial_blocks3: x3 = blk(x3, size=(H3, W3)) @@ -557,7 +516,8 @@ class CoaT(nn.Module): x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous() # Serial blocks 4. 
- x4, (H4, W4) = self.patch_embed4(x3_nocls) + x4 = self.patch_embed4(x3_nocls) + H4, W4 = self.patch_embed4.out_size x4 = self.insert_cls(x4, self.cls_token4) for blk in self.serial_blocks4: x4 = blk(x4, size=(H4, W4)) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index eecbbde4..90241f5c 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -20,9 +20,11 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible from .inplace_abn import InplaceAbn from .linear import Linear from .mixed_conv2d import MixedConv2d +from .mlp import Mlp, GluMlp from .norm import GroupNorm from .norm_act import BatchNormAct2d, GroupNormAct from .padding import get_padding, get_same_padding, pad_same +from .patch_embed import PatchEmbed from .pool2d_same import AvgPool2dSame, create_pool2d from .se import SEModule from .selective_kernel import SelectiveKernelConv diff --git a/timm/models/layers/mlp.py b/timm/models/layers/mlp.py new file mode 100644 index 00000000..b65c8d07 --- /dev/null +++ b/timm/models/layers/mlp.py @@ -0,0 +1,49 @@ +""" MLP module w/ dropout and configurable activation layer + +Hacked together by / Copyright 2020 Ross Wightman +""" +from torch import nn as nn + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class GluMlp(nn.Module): + """ MLP w/ GLU style gating + See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features * 2) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x, gates = x.chunk(2, dim=-1) + x = x * self.act(gates) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py new file mode 100644 index 00000000..f7a07e18 --- /dev/null +++ b/timm/models/layers/patch_embed.py @@ -0,0 +1,36 @@ +""" Image to Patch Embedding using Conv2d + +A convolution based approach to patchifying a 2D image w/ embedding projection. 
+ +Based on the impl in https://github.com/google-research/vision_transformer + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from torch import nn as nn + +from .helpers import to_2tuple + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.out_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.out_size[0] * self.out_size[1] + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm(x) + return x diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index e044e961..c2c96e6c 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -25,7 +25,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .layers import DropPath, to_2tuple, lecun_normal_ +from .layers import PatchEmbed, Mlp, GluMlp, DropPath, lecun_normal_ from .registry import register_model @@ -43,6 +43,7 @@ def _cfg(url='', **kwargs): default_cfgs = dict( mixer_s32_224=_cfg(), mixer_s16_224=_cfg(), + mixer_s16_glu_224=_cfg(), mixer_b32_224=_cfg(), mixer_b16_224=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth', @@ -62,65 +63,17 @@ default_cfgs = dict( ) -class Mlp(nn.Module): - """ MLP Block - NOTE: same impl as ViT, move to common location - """ - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class PatchEmbed(nn.Module): - """ Image to Patch Embedding - NOTE: same impl as ViT, move to common location - """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.patch_grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.patch_grid[0] * self.patch_grid[1] - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
- x = self.proj(x).flatten(2).transpose(1, 2) - x = self.norm(x) - return x - - class MixerBlock(nn.Module): def __init__( self, dim, seq_len, tokens_dim, channels_dim, - norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): + mlp_layer=Mlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.): super().__init__() self.norm1 = norm_layer(dim) - self.mlp_tokens = Mlp(seq_len, tokens_dim, act_layer=act_layer, drop=drop) + self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) - self.mlp_channels = Mlp(dim, channels_dim, act_layer=act_layer, drop=drop) + self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) def forward(self, x): x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2)) @@ -140,6 +93,7 @@ class MlpMixer(nn.Module): hidden_dim=512, tokens_dim=256, channels_dim=2048, + mlp_layer=Mlp, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., @@ -154,7 +108,7 @@ class MlpMixer(nn.Module): self.blocks = nn.Sequential(*[ MixerBlock( hidden_dim, self.stem.num_patches, tokens_dim, channels_dim, - norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path) + mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop, drop_path=drop_path) for _ in range(num_blocks)]) self.norm = norm_layer(hidden_dim) self.head = nn.Linear(hidden_dim, self.num_classes) # zero init @@ -238,6 +192,17 @@ def mixer_s16_224(pretrained=False, **kwargs): return model +@register_model +def mixer_s16_glu_224(pretrained=False, **kwargs): + """ Mixer-S/16 224x224 + """ + model_args = dict( + patch_size=16, num_blocks=8, hidden_dim=512, tokens_dim=256, channels_dim=1536, + mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) + model = _create_mixer('mixer_s16_glu_224', pretrained=pretrained, **model_args) + return model + + @register_model def mixer_b32_224(pretrained=False, **kwargs): """ Mixer-B/32 224x224 diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 491d9acb..2b0b0339 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -49,6 +49,9 @@ default_cfgs = { 'resnet26d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth', interpolation='bicubic', first_conv='conv1.0'), + 'resnet26t': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)), 'resnet50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth', interpolation='bicubic'), @@ -723,6 +726,15 @@ def resnet26(pretrained=False, **kwargs): return _create_resnet('resnet26', pretrained, **model_args) +@register_model +def resnet26t(pretrained=False, **kwargs): + """Constructs a ResNet-26-T model. + """ + model_args = dict( + block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) + return _create_resnet('resnet26t', pretrained, **model_args) + + @register_model def resnet26d(pretrained=False, **kwargs): """Constructs a ResNet-26-D model. 
diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index a3fd3de7..2880aa02 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -22,9 +22,9 @@ import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .layers import DropPath, to_2tuple, trunc_normal_ +from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ from .registry import register_model -from .vision_transformer import checkpoint_filter_fn, Mlp, PatchEmbed, _init_vit_weights +from .vision_transformer import checkpoint_filter_fn, _init_vit_weights _logger = logging.getLogger(__name__) @@ -467,7 +467,7 @@ class SwinTransformer(nn.Module): img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer if self.patch_norm else None) num_patches = self.patch_embed.num_patches - self.patch_grid = self.patch_embed.patch_grid + self.patch_grid = self.patch_embed.out_size # absolute position embedding if self.ape: diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 42c03e61..cc732677 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -13,8 +13,7 @@ from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import load_pretrained -from timm.models.layers import DropPath, trunc_normal_ -from timm.models.vision_transformer import Mlp +from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.registry import register_model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 4bf1dec5..cc7e0903 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -29,7 +29,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .layers import DropPath, to_2tuple, trunc_normal_, lecun_normal_ +from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ from .registry import register_model _logger = logging.getLogger(__name__) @@ -132,25 +132,6 @@ default_cfgs = { } -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() @@ -198,31 +179,6 @@ class Block(nn.Module): return x -class PatchEmbed(nn.Module): - """ Image to Patch Embedding - """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.patch_grid = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) - self.num_patches = self.patch_grid[0] * self.patch_grid[1] - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, 
stride=patch_size) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) - x = self.norm(x) - return x - - class VisionTransformer(nn.Module): """ Vision Transformer
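
The snippet below is a minimal usage sketch, not part of the patch itself. It assumes a tree with this patch applied and shows the consolidated imports the change enables; the tensor shapes in the comments follow the new layers/patch_embed.py and layers/mlp.py exactly as added above (note the GluMlp here doubles the hidden width in fc1 and feeds the gated half to fc2, per this patch's version of the module).

import torch
from torch import nn
from timm.models.layers import PatchEmbed, Mlp, GluMlp

# Patch embedding: 224x224 RGB image -> (B, num_patches, embed_dim) token sequence.
embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None)
x = torch.randn(2, 3, 224, 224)
tokens = embed(x)                 # (2, 196, 768); spatial grid available as embed.out_size

# Plain ViT-style MLP: fc1 -> GELU -> dropout -> fc2 -> dropout, last dim preserved.
mlp = Mlp(in_features=768, hidden_features=3072, drop=0.1)
y = mlp(tokens)                   # (2, 196, 768)

# GLU-gated MLP as defined in this patch: fc1 projects to 2 * hidden_features,
# chunk(2) splits value and gate, fc2 maps the gated hidden back to out_features.
glu = GluMlp(in_features=768, hidden_features=1536, act_layer=nn.SiLU)
z = glu(tokens)                   # (2, 196, 768)

This is the pattern the per-model import changes above reduce to: coat.py, mlp_mixer.py, swin_transformer.py, tnt.py and vision_transformer.py now share these single implementations instead of carrying their own copies.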