Add layer scale + affine option to perceiver

perceiver
Ross Wightman 3 years ago
parent 820c262f33
commit 55b135d5c6
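This commit adds two building blocks to the Perceiver implementation: LayerScale, a learnable per-channel gain on each residual branch initialized to a small value (as in CaiT), and Affine, a learnable per-channel scale + shift that can stand in for LayerNorm (as in ResMLP). A perceiver_m_ls model variant wires both in. As a rough, self-contained sketch of the residual pattern the new *LayerScale blocks below follow (illustrative names, not code from this diff):

import torch
import torch.nn as nn

class ResidualLayerScale(nn.Module):
    # y = x + gamma * fn(norm(x)); gamma starts tiny so each residual branch begins close to identity
    def __init__(self, dim, fn, init_values=1e-5):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x + self.gamma * self.fn(self.norm(x))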

@@ -26,7 +26,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*']
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'perceiver_l*']
else:
EXCLUDE_FILTERS = []

@@ -14,17 +14,15 @@ Status:
Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from collections import OrderedDict
from functools import partial
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply
from .layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_ntuple, ConvBnAct, LayerNorm2d
from .layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_ntuple
from .registry import register_model
@@ -47,6 +45,8 @@ default_cfgs = {
url='', input_size=(3, 192, 192)),
'perceiver_m': _cfg(
url=''),
'perceiver_m_ls': _cfg(
url=''),
'perceiver_l': _cfg(
url=''),
}
@@ -130,6 +130,16 @@ class CrossAttention(nn.Module):
return out
class Affine(nn.Module):
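""" Per-channel affine transform (ResMLP style): forward computes beta + alpha * x via torch.addcmul.
Used below as a norm_layer replacement in the perceiver_m_ls variant.
"""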
def __init__(self, dim):
super().__init__()
self.alpha = nn.Parameter(torch.ones((1, 1, dim)))
self.beta = nn.Parameter(torch.zeros((1, 1, dim)))
def forward(self, x):
return torch.addcmul(self.beta, self.alpha, x)
@torch.jit.interface
class CrossInterface(torch.nn.Module):
def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
@@ -159,6 +169,31 @@ class CrossBlock(nn.Module):
return latent
class CrossBlockLayerScale(nn.Module):
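""" CrossBlock variant with CaiT-style LayerScale: the cross attention and MLP residual branches are
scaled by learnable per-channel gains (ls1, ls2) initialized to init_values.
"""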
def __init__(self, latent_dim, data_dim, num_heads, attn_dim=None, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
drop=0., drop_path=0., attn_layer=CrossAttention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1_latent = norm_layer(latent_dim)
self.norm1_data = norm_layer(data_dim)
self.attn = attn_layer(
latent_dim, data_dim, num_heads=num_heads, attn_dim=attn_dim, qkv_bias=qkv_bias, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(latent_dim)
mlp_hidden_dim = int(latent_dim * mlp_ratio)
self.mlp = Mlp(in_features=latent_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.ls1 = nn.Parameter(init_values * torch.ones(latent_dim))
self.ls2 = nn.Parameter(init_values * torch.ones(latent_dim))
def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
latent = latent + self.drop_path(self.ls1 * self.attn(
self.norm1_latent(latent),
self.norm1_data(data),
))
latent = latent + self.drop_path(self.ls2 * self.mlp(self.norm2(latent)))
return latent
@torch.jit.interface
class TransformerInterface(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -167,7 +202,7 @@ class TransformerInterface(torch.nn.Module):
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
@@ -182,15 +217,37 @@ class TransformerBlock(nn.Module):
return x
class TransformerBlockLayerScale(nn.Module):
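""" TransformerBlock variant with CaiT-style LayerScale (ls1, ls2) on the attention and MLP residual branches. """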
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.ls1 = nn.Parameter(init_values * torch.ones(dim))
self.ls2 = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
x = x + self.drop_path(self.ls1 * self.attn(self.norm1(x)))
x = x + self.drop_path(self.ls2 * self.mlp(self.norm2(x)))
return x
class TransformerStack(nn.Module):
""" A stack-o-transformers
NOTE: this could have been a simple nn.Sequential, but it is wrapped in a module so the Interface
def can be used for ModuleDict torchscript compat (see the sketch after this class).
"""
def __init__(self, depth, dim, num_heads, mlp_ratio=4., **kwargs):
def __init__(self, depth, dim, num_heads, mlp_ratio=4., block=None, **kwargs):
super().__init__()
block = block or TransformerBlock
self.stack = nn.Sequential(*[
TransformerBlock(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, **kwargs) for _ in range(depth)])
block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, **kwargs)
for _ in range(depth)
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.stack(x)
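The @torch.jit.interface classes above exist so that modules pulled out of an nn.ModuleDict can be called with a typed signature under torchscript. A minimal standalone sketch of that pattern (names here are illustrative, not from this diff):

import torch
import torch.nn as nn

@torch.jit.interface
class BlockInterface(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass

class AddOne(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1.0

class DictRunner(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleDict({'a': AddOne(), 'b': AddOne()})
        self.block_keys = ['a', 'b']

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for k in self.block_keys:
            block: BlockInterface = self.blocks[k]  # interface annotation gives the lookup a typed forward
            x = block.forward(x)  # call .forward explicitly; only forward is declared on the interface
        return x

scripted = torch.jit.script(DictRunner())
print(scripted(torch.zeros(2)))  # tensor([2., 2.])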
@@ -235,7 +292,8 @@ class Perceiver(nn.Module):
latent_dim=1024, num_latents=512, num_latent_heads=8, latent_mlp_ratio=1.0,
cross_attn_dim=None, num_cross_heads=1, cross_mlp_ratio=1.0, share_weights=(1, 0),
pos_embed_type='fourier', pos_embed_dim=128, data_bands=64, data_ndim=2, data_max_freq=10,
data_spatial=False, qkv_bias=True, cross_attn_layer=None, attn_layer=None, norm_layer=None, act_layer=None,
data_spatial=False, qkv_bias=True, cross_block=None, transformer_block=None,
cross_attn_layer=None, attn_layer=None, norm_layer=None, act_layer=None,
drop_rate=0., drop_path_rate=0., weight_init=''):
"""
Args:
@@ -263,6 +321,8 @@ class Perceiver(nn.Module):
super().__init__()
self.num_classes = num_classes
self.num_features = self.latent_dim = latent_dim
cross_block = cross_block or CrossBlock
transformer_block = transformer_block or TransformerBlock
cross_attn_layer = cross_attn_layer or CrossAttention
attn_layer = attn_layer or Attention
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
@@ -282,13 +342,13 @@ class Perceiver(nn.Module):
stage_args = dict(
qkv_bias=qkv_bias, drop=drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer, act_layer=act_layer)
if k.startswith('c'):
self.blocks_cross[k] = CrossBlock(
self.blocks_cross[k] = cross_block(
latent_dim=latent_dim, data_dim=self.data_dim, attn_dim=cross_attn_dim, num_heads=num_cross_heads,
mlp_ratio=cross_mlp_ratio, attn_layer=cross_attn_layer, **stage_args)
else:
self.blocks_trans[k] = TransformerStack(
depth=transformer_depth, dim=latent_dim, num_heads=num_latent_heads,
mlp_ratio=latent_mlp_ratio, attn_layer=attn_layer, **stage_args)
mlp_ratio=latent_mlp_ratio, attn_layer=attn_layer, block=transformer_block, **stage_args)
self.norm = norm_layer(latent_dim)
self.head = nn.Linear(latent_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -408,6 +468,20 @@ def perceiver_m(pretrained=False, **kwargs):
return model
@register_model
def perceiver_m_ls(pretrained=False, **kwargs):
""" Perceiver-Medium w/ LayerScale + Affine
Two cross attention blocks (one per initial transformer stack), all transformer stacks share weights. ~25M params.
LayerScale + Affine are influenced by CaiT, LeViT, and ResMLP from Facebook AI.
"""
model_kwargs = dict(
cross_depths=(1,) * 2, latent_dim=768, num_latents=384, cross_attn_dim=160, data_bands=40,
transformer_block=TransformerBlockLayerScale, cross_block=CrossBlockLayerScale,
norm_layer=Affine, **kwargs)
model = _create_perceiver('perceiver_m_ls', pretrained=pretrained, **model_kwargs)
return model
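A hypothetical usage sketch for the new variant, assuming a timm install that includes this commit; no pretrained weights exist yet (the cfg url is empty), and the 224x224 input resolution is an assumption based on the default cfg:

import torch
import timm

model = timm.create_model('perceiver_m_ls', pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # assumed default resolution; yields num_classes logits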
@register_model
def perceiver_l(pretrained=False, **kwargs):
""" Perceiver-Large
