Add layer scale + affine option to perceiver

perceiver
Ross Wightman 3 years ago
parent 820c262f33
commit 55b135d5c6

@@ -26,7 +26,7 @@ if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
     EXCLUDE_FILTERS = [
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
-        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*']
+        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'perceiver_l*']
 else:
     EXCLUDE_FILTERS = []

@@ -14,17 +14,15 @@ Status:
 Hacked together by / copyright Ross Wightman, 2021.
 """
 import math
-from collections import OrderedDict
 from functools import partial
 from typing import List, Tuple
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg, named_apply
-from .layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_ntuple, ConvBnAct, LayerNorm2d
+from .layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_ntuple
 from .registry import register_model
@@ -47,6 +45,8 @@ default_cfgs = {
         url='', input_size=(3, 192, 192)),
     'perceiver_m': _cfg(
         url=''),
+    'perceiver_m_ls': _cfg(
+        url=''),
     'perceiver_l': _cfg(
         url=''),
 }
@@ -130,6 +130,16 @@ class CrossAttention(nn.Module):
         return out
 
 
+class Affine(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.ones((1, 1, dim)))
+        self.beta = nn.Parameter(torch.zeros((1, 1, dim)))
+
+    def forward(self, x):
+        return torch.addcmul(self.beta, self.alpha, x)
+
+
 @torch.jit.interface
 class CrossInterface(torch.nn.Module):
     def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
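For reference, the new Affine module is a learnable per-channel scale-and-shift that starts out as the identity: torch.addcmul(beta, alpha, x) computes beta + alpha * x elementwise with broadcasting. A minimal standalone sketch, with shapes assumed from the usual (batch, tokens, dim) layout:

import torch

dim = 768                           # assumed channel dim
alpha = torch.ones(1, 1, dim)       # scale, init to 1
beta = torch.zeros(1, 1, dim)       # shift, init to 0
x = torch.randn(2, 384, dim)        # (batch, tokens, dim)
y = torch.addcmul(beta, alpha, x)   # what Affine.forward computes
assert torch.allclose(y, alpha * x + beta)

ResMLP uses this kind of affine transform in place of LayerNorm, which is why perceiver_m_ls below passes norm_layer=Affine.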
@@ -159,6 +169,31 @@ class CrossBlock(nn.Module):
         return latent
 
 
+class CrossBlockLayerScale(nn.Module):
+    def __init__(self, latent_dim, data_dim, num_heads, attn_dim=None, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
+                 drop=0., drop_path=0., attn_layer=CrossAttention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1_latent = norm_layer(latent_dim)
+        self.norm1_data = norm_layer(data_dim)
+        self.attn = attn_layer(
+            latent_dim, data_dim, num_heads=num_heads, attn_dim=attn_dim, qkv_bias=qkv_bias, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(latent_dim)
+        mlp_hidden_dim = int(latent_dim * mlp_ratio)
+        self.mlp = Mlp(in_features=latent_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.ls1 = nn.Parameter(init_values * torch.ones(latent_dim))
+        self.ls2 = nn.Parameter(init_values * torch.ones(latent_dim))
+
+    def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
+        latent = latent + self.drop_path(self.ls1 * self.attn(
+            self.norm1_latent(latent),
+            self.norm1_data(data),
+        ))
+        latent = latent + self.drop_path(self.ls2 * self.mlp(self.norm2(latent)))
+        return latent
+
+
 @torch.jit.interface
 class TransformerInterface(torch.nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
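The ls1/ls2 parameters implement LayerScale as in CaiT: each residual branch is gated by a learnable per-channel vector initialized to a small constant (init_values=1e-5), so every block starts close to the identity and deeper stacks train more stably. A standalone sketch of the residual update, with assumed dims and drop_path omitted:

import torch
import torch.nn as nn

dim, init_values = 768, 1e-5
ls1 = nn.Parameter(init_values * torch.ones(dim))   # per-channel gate

latent = torch.randn(2, 384, dim)
branch = torch.randn(2, 384, dim)   # stand-in for the attn/mlp branch output
latent = latent + ls1 * branch      # gate broadcasts over (batch, tokens, dim)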
@@ -167,7 +202,7 @@ class TransformerInterface(torch.nn.Module):
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
                  drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
@@ -182,15 +217,37 @@ class TransformerBlock(nn.Module):
         return x
 
 
+class TransformerBlockLayerScale(nn.Module):
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
+                 drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+        self.ls1 = nn.Parameter(init_values * torch.ones(dim))
+        self.ls2 = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        x = x + self.drop_path(self.ls1 * self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.ls2 * self.mlp(self.norm2(x)))
+        return x
+
+
 class TransformerStack(nn.Module):
     """ A stack-o-transformers
     NOTE this could have been a simple nn.Sequential but needed to wrap in module to use Interface
     def for ModuleDict torchscript compat.
     """
-    def __init__(self, depth, dim, num_heads, mlp_ratio=4., **kwargs):
+    def __init__(self, depth, dim, num_heads, mlp_ratio=4., block=None, **kwargs):
         super().__init__()
+        block = block or TransformerBlock
         self.stack = nn.Sequential(*[
-            TransformerBlock(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, **kwargs) for _ in range(depth)])
+            block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, **kwargs)
+            for _ in range(depth)
+        ])
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.stack(x)
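With the new block argument, a stack can be assembled from either block type without modifying TransformerStack. A hypothetical usage sketch, assuming this branch is installed and the classes above are in scope:

import torch

stack = TransformerStack(
    depth=6, dim=768, num_heads=8, mlp_ratio=1.0,
    block=TransformerBlockLayerScale, init_values=1e-5)  # extra kwargs flow through to the block

x = torch.randn(2, 384, 768)   # (batch, num_latents, latent_dim)
out = stack(x)                 # shape preserved: (2, 384, 768)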
@@ -235,7 +292,8 @@ class Perceiver(nn.Module):
             latent_dim=1024, num_latents=512, num_latent_heads=8, latent_mlp_ratio=1.0,
             cross_attn_dim=None, num_cross_heads=1, cross_mlp_ratio=1.0, share_weights=(1, 0),
             pos_embed_type='fourier', pos_embed_dim=128, data_bands=64, data_ndim=2, data_max_freq=10,
-            data_spatial=False, qkv_bias=True, cross_attn_layer=None, attn_layer=None, norm_layer=None, act_layer=None,
+            data_spatial=False, qkv_bias=True, cross_block=None, transformer_block=None,
+            cross_attn_layer=None, attn_layer=None, norm_layer=None, act_layer=None,
             drop_rate=0., drop_path_rate=0., weight_init=''):
         """
         Args:
@@ -263,6 +321,8 @@ class Perceiver(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.num_features = self.latent_dim = latent_dim
+        cross_block = cross_block or CrossBlock
+        transformer_block = transformer_block or TransformerBlock
         cross_attn_layer = cross_attn_layer or CrossAttention
         attn_layer = attn_layer or Attention
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
@@ -282,13 +342,13 @@ class Perceiver(nn.Module):
             stage_args = dict(
                 qkv_bias=qkv_bias, drop=drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer, act_layer=act_layer)
             if k.startswith('c'):
-                self.blocks_cross[k] = CrossBlock(
+                self.blocks_cross[k] = cross_block(
                     latent_dim=latent_dim, data_dim=self.data_dim, attn_dim=cross_attn_dim, num_heads=num_cross_heads,
                     mlp_ratio=cross_mlp_ratio, attn_layer=cross_attn_layer, **stage_args)
             else:
                 self.blocks_trans[k] = TransformerStack(
                     depth=transformer_depth, dim=latent_dim, num_heads=num_latent_heads,
-                    mlp_ratio=latent_mlp_ratio, attn_layer=attn_layer, **stage_args)
+                    mlp_ratio=latent_mlp_ratio, attn_layer=attn_layer, block=transformer_block, **stage_args)
 
         self.norm = norm_layer(latent_dim)
         self.head = nn.Linear(latent_dim, num_classes) if num_classes > 0 else nn.Identity()
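Net effect of the constructor changes: the LayerScale block variants (and any norm layer) can be injected without subclassing Perceiver. A hedged sketch of direct construction; only the new keyword arguments are from the diff, the remaining constructor defaults are assumed:

# Hypothetical direct construction (argument values chosen for illustration):
model = Perceiver(
    num_classes=1000,
    cross_block=CrossBlockLayerScale,
    transformer_block=TransformerBlockLayerScale,
    norm_layer=Affine)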
@@ -408,6 +468,20 @@ def perceiver_m(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def perceiver_m_ls(pretrained=False, **kwargs):
+    """ Perceiver-Medium w/ LayerScale + Affine
+    Two cross attn (one for each initial transformer stack), all transformers shared. ~25M params.
+    LayerScale + Affine influenced by CaiT, LeViT, ResMLP from Facebook AI.
+    """
+    model_kwargs = dict(
+        cross_depths=(1,) * 2, latent_dim=768, num_latents=384, cross_attn_dim=160, data_bands=40,
+        transformer_block=TransformerBlockLayerScale, cross_block=CrossBlockLayerScale,
+        norm_layer=Affine, **kwargs)
+    model = _create_perceiver('perceiver_m_ls', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def perceiver_l(pretrained=False, **kwargs):
     """ Perceiver-Large
