From 55b135d5c6b9c22720559616e88f8c130973bedd Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 22 Jul 2021 14:27:27 -0700
Subject: [PATCH] Add layer scale + affine option to perceiver

---
 tests/test_models.py     |  2 +-
 timm/models/perceiver.py | 92 ++++++++++++++++++++++++++++++++++++----
 2 files changed, 84 insertions(+), 10 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index cf7e6038..e275e430 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -26,7 +26,7 @@ if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
     EXCLUDE_FILTERS = [
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
-        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*']
+        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'perceiver_l*']
 else:
     EXCLUDE_FILTERS = []
 
diff --git a/timm/models/perceiver.py b/timm/models/perceiver.py
index d8230f11..76a153e5 100644
--- a/timm/models/perceiver.py
+++ b/timm/models/perceiver.py
@@ -14,17 +14,15 @@ Status:
 Hacked together by / copyright Ross Wightman, 2021.
 """
 import math
-from collections import OrderedDict
 from functools import partial
 from typing import List, Tuple
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg, named_apply
-from .layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_ntuple, ConvBnAct, LayerNorm2d
+from .layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_ntuple
 from .registry import register_model
 
 
@@ -47,6 +45,8 @@ default_cfgs = {
         url='', input_size=(3, 192, 192)),
     'perceiver_m': _cfg(
         url=''),
+    'perceiver_m_ls': _cfg(
+        url=''),
     'perceiver_l': _cfg(
         url=''),
 }
@@ -130,6 +130,16 @@ class CrossAttention(nn.Module):
         return out
 
 
+class Affine(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones((1, 1, dim)))
+        self.beta = nn.Parameter(torch.zeros((1, 1, dim)))
+
+    def forward(self, x):
+        return torch.addcmul(self.beta, self.alpha, x)
+
+
 @torch.jit.interface
 class CrossInterface(torch.nn.Module):
     def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
@@ -159,6 +169,31 @@ class CrossBlock(nn.Module):
         return latent
 
 
+class CrossBlockLayerScale(nn.Module):
+
+    def __init__(self, latent_dim, data_dim, num_heads, attn_dim=None, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
+                 drop=0., drop_path=0., attn_layer=CrossAttention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1_latent = norm_layer(latent_dim)
+        self.norm1_data = norm_layer(data_dim)
+        self.attn = attn_layer(
+            latent_dim, data_dim, num_heads=num_heads, attn_dim=attn_dim, qkv_bias=qkv_bias, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(latent_dim)
+        mlp_hidden_dim = int(latent_dim * mlp_ratio)
+        self.mlp = Mlp(in_features=latent_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.ls1 = nn.Parameter(init_values * torch.ones(latent_dim))
+        self.ls2 = nn.Parameter(init_values * torch.ones(latent_dim))
+
+    def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
+        latent = latent + self.drop_path(self.ls1 * self.attn(
+            self.norm1_latent(latent),
+            self.norm1_data(data),
+        ))
+        latent = latent + self.drop_path(self.ls2 * self.mlp(self.norm2(latent)))
+        return latent
+
+
 @torch.jit.interface
 class TransformerInterface(torch.nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -167,7 +202,7 @@ class TransformerInterface(torch.nn.Module):
 
 class TransformerBlock(nn.Module):
 
-    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
                  drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
@@ -182,15 +217,37 @@ class TransformerBlock(nn.Module):
         return x
 
 
+class TransformerBlockLayerScale(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
+                 drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.ls1 = nn.Parameter(init_values * torch.ones(dim))
+        self.ls2 = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        x = x + self.drop_path(self.ls1 * self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.ls2 * self.mlp(self.norm2(x)))
+        return x
+
+
 class TransformerStack(nn.Module):
     """ A stack-o-transformers
     NOTE this could have been a simple nn.Sequential but needed to wrap in module to use
     Interface def for ModuleDict torchscript compat.
     """
-    def __init__(self, depth, dim, num_heads, mlp_ratio=4., **kwargs):
+    def __init__(self, depth, dim, num_heads, mlp_ratio=4., block=None, **kwargs):
         super().__init__()
+        block = block or TransformerBlock
         self.stack = nn.Sequential(*[
-            TransformerBlock(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, **kwargs) for _ in range(depth)])
+            block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, **kwargs)
+            for _ in range(depth)
+        ])
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.stack(x)
@@ -235,7 +292,8 @@ class Perceiver(nn.Module):
             latent_dim=1024, num_latents=512, num_latent_heads=8, latent_mlp_ratio=1.0,
             cross_attn_dim=None, num_cross_heads=1, cross_mlp_ratio=1.0, share_weights=(1, 0),
             pos_embed_type='fourier', pos_embed_dim=128, data_bands=64, data_ndim=2, data_max_freq=10,
-            data_spatial=False, qkv_bias=True, cross_attn_layer=None, attn_layer=None, norm_layer=None, act_layer=None,
+            data_spatial=False, qkv_bias=True, cross_block=None, transformer_block=None,
+            cross_attn_layer=None, attn_layer=None, norm_layer=None, act_layer=None,
             drop_rate=0., drop_path_rate=0., weight_init=''):
         """
         Args:
@@ -263,6 +321,8 @@ class Perceiver(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.num_features = self.latent_dim = latent_dim
+        cross_block = cross_block or CrossBlock
+        transformer_block = transformer_block or TransformerBlock
         cross_attn_layer = cross_attn_layer or CrossAttention
         attn_layer = attn_layer or Attention
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
@@ -282,13 +342,13 @@
             stage_args = dict(
                 qkv_bias=qkv_bias, drop=drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer, act_layer=act_layer)
             if k.startswith('c'):
-                self.blocks_cross[k] = CrossBlock(
+                self.blocks_cross[k] = cross_block(
                     latent_dim=latent_dim, data_dim=self.data_dim, attn_dim=cross_attn_dim,
                     num_heads=num_cross_heads, mlp_ratio=cross_mlp_ratio, attn_layer=cross_attn_layer, **stage_args)
             else:
                 self.blocks_trans[k] = TransformerStack(
                     depth=transformer_depth, dim=latent_dim, num_heads=num_latent_heads,
-                    mlp_ratio=latent_mlp_ratio, attn_layer=attn_layer, **stage_args)
+                    mlp_ratio=latent_mlp_ratio, attn_layer=attn_layer, block=transformer_block, **stage_args)
 
         self.norm = norm_layer(latent_dim)
         self.head = nn.Linear(latent_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -408,6 +468,20 @@ def perceiver_m(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def perceiver_m_ls(pretrained=False, **kwargs):
+    """ Perceiver-Medium w/ LayerScale + Affine
+    Two cross attn (one per each initial transformer stack), all transformers shared. ~25M params.
+    LayerScale + Affine influenced by CaiT, LeViT, ResMLP from Facebook AI
+    """
+    model_kwargs = dict(
+        cross_depths=(1,) * 2, latent_dim=768, num_latents=384, cross_attn_dim=160, data_bands=40,
+        transformer_block=TransformerBlockLayerScale, cross_block=CrossBlockLayerScale,
+        norm_layer=Affine, **kwargs)
+    model = _create_perceiver('perceiver_m_ls', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def perceiver_l(pretrained=False, **kwargs):
     """ Perceiver-Large