Compare commits
3 Commits
Author | SHA1 | Date |
---|---|---|
Ross Wightman | 55b135d5c6 | 3 years ago |
Ross Wightman | 820c262f33 | 3 years ago |
Ross Wightman | 77698d80a5 | 3 years ago |
@ -0,0 +1,493 @@
|
||||
""" Perceiver
|
||||
|
||||
Paper: `Perceiver: General Perception with Iterative Attention` - https://arxiv.org/abs/2103.03206
|
||||
|
||||
Official Deepmind code: TBD (doesn't exist yet)
|
||||
|
||||
Fourier feature position embedding references:
|
||||
* Official NeRF impl - https://github.com/bmild/nerf
|
||||
* Lucidrain's Perceiver impl - https://github.com/lucidrains/perceiver-pytorch
|
||||
|
||||
Status:
|
||||
* Work in progress, currently running training trials with S and M models (rather slow)
|
||||
|
||||
Hacked together by / copyright Ross Wightman, 2021.
|
||||
"""
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .helpers import build_model_with_cfg, named_apply
|
||||
from .layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_ntuple
|
||||
from .registry import register_model
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': None, 'classifier': 'head',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
# patch models (weights from official Google JAX impl)
|
||||
'perceiver_ss': _cfg(
|
||||
url='', input_size=(3, 192, 192)),
|
||||
'perceiver_s': _cfg(
|
||||
url='', input_size=(3, 192, 192)),
|
||||
'perceiver_m': _cfg(
|
||||
url=''),
|
||||
'perceiver_m_ls': _cfg(
|
||||
url=''),
|
||||
'perceiver_l': _cfg(
|
||||
url=''),
|
||||
}
|
||||
|
||||
|
||||
def fourier_encode(x, max_freq_log2: int = 8, num_bands: int = 64):
|
||||
""" Fourier feature embedding.
|
||||
Referenced official NeRF code and Lucidrain's PyTorch Perceiver impl.
|
||||
"""
|
||||
# FIXME this will likely need to change once official code / weights are available
|
||||
x = x.unsqueeze(-1)
|
||||
bands = 2 ** torch.linspace(0, max_freq_log2 - 1, num_bands, device=x.device, dtype=x.dtype)
|
||||
x_bands = x * math.pi * bands
|
||||
x = torch.cat([x, x_bands.sin(), x_bands.cos()], dim=-1)
|
||||
return x
|
||||
|
||||
|
||||
def fourier_grid(
|
||||
shape: List[int], max_freq_log2: int = 8, num_bands: int = 64, device: torch.device = torch.device('cuda')):
|
||||
grid = torch.stack(torch.meshgrid([torch.linspace(-1., 1., steps=s, device=device) for s in shape]), dim=-1)
|
||||
enc_pos = fourier_encode(grid, max_freq_log2, num_bands)
|
||||
return enc_pos.transpose(-1, -2).flatten(len(shape))
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, proj_drop=0.):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
"""
|
||||
"""
|
||||
def __init__(self, latent_dim, data_dim, attn_dim=None, num_heads=1, qkv_bias=True, proj_drop=0.):
|
||||
super().__init__()
|
||||
assert latent_dim % num_heads == 0, f"dim {latent_dim} should be divided by num_heads {num_heads}."
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.attn_dim = attn_dim or min(latent_dim, data_dim)
|
||||
self.num_heads = num_heads
|
||||
head_dim = self.attn_dim // num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.q = nn.Linear(latent_dim, self.attn_dim, bias=qkv_bias)
|
||||
self.kv = nn.Linear(data_dim, self.attn_dim * 2, bias=qkv_bias)
|
||||
self.proj = nn.Linear(self.attn_dim, latent_dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, latent, data):
|
||||
B = latent.shape[0]
|
||||
q = self.q(latent).reshape(B, -1, self.num_heads, self.attn_dim // self.num_heads).permute(0, 2, 1, 3)
|
||||
|
||||
kv = self.kv(data).reshape(B, -1, 2, self.num_heads, self.attn_dim // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
k, v = kv[0], kv[1]
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
out = (attn @ v).transpose(1, 2).reshape(B, -1, self.attn_dim)
|
||||
out = self.proj(out)
|
||||
out = self.proj_drop(out)
|
||||
return out
|
||||
|
||||
|
||||
class Affine(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.alpha = nn.Parameter(torch.ones((1, 1, dim)))
|
||||
self.beta = nn.Parameter(torch.zeros((1, 1, dim)))
|
||||
|
||||
def forward(self, x):
|
||||
return torch.addcmul(self.beta, self.alpha, x)
|
||||
|
||||
|
||||
@torch.jit.interface
|
||||
class CrossInterface(torch.nn.Module):
|
||||
def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
|
||||
class CrossBlock(nn.Module):
|
||||
|
||||
def __init__(self, latent_dim, data_dim, num_heads, attn_dim=None, mlp_ratio=4., qkv_bias=True,
|
||||
drop=0., drop_path=0., attn_layer=CrossAttention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.norm1_latent = norm_layer(latent_dim)
|
||||
self.norm1_data = norm_layer(data_dim)
|
||||
self.attn = attn_layer(
|
||||
latent_dim, data_dim, num_heads=num_heads, attn_dim=attn_dim, qkv_bias=qkv_bias, proj_drop=drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(latent_dim)
|
||||
mlp_hidden_dim = int(latent_dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=latent_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
|
||||
latent = latent + self.drop_path(self.attn(
|
||||
self.norm1_latent(latent),
|
||||
self.norm1_data(data),
|
||||
))
|
||||
latent = latent + self.drop_path(self.mlp(self.norm2(latent)))
|
||||
return latent
|
||||
|
||||
|
||||
class CrossBlockLayerScale(nn.Module):
|
||||
|
||||
def __init__(self, latent_dim, data_dim, num_heads, attn_dim=None, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
|
||||
drop=0., drop_path=0., attn_layer=CrossAttention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.norm1_latent = norm_layer(latent_dim)
|
||||
self.norm1_data = norm_layer(data_dim)
|
||||
self.attn = attn_layer(
|
||||
latent_dim, data_dim, num_heads=num_heads, attn_dim=attn_dim, qkv_bias=qkv_bias, proj_drop=drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(latent_dim)
|
||||
mlp_hidden_dim = int(latent_dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=latent_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
self.ls1 = nn.Parameter(init_values * torch.ones(latent_dim))
|
||||
self.ls2 = nn.Parameter(init_values * torch.ones(latent_dim))
|
||||
|
||||
def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
|
||||
latent = latent + self.drop_path(self.ls1 * self.attn(
|
||||
self.norm1_latent(latent),
|
||||
self.norm1_data(data),
|
||||
))
|
||||
latent = latent + self.drop_path(self.ls2 * self.mlp(self.norm2(latent)))
|
||||
return latent
|
||||
|
||||
|
||||
@torch.jit.interface
|
||||
class TransformerInterface(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
|
||||
drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class TransformerBlockLayerScale(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
|
||||
drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=drop)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
|
||||
self.ls1 = nn.Parameter(init_values * torch.ones(dim))
|
||||
self.ls2 = nn.Parameter(init_values * torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.ls1 * self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.ls2 * self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class TransformerStack(nn.Module):
|
||||
""" A stack-o-transformers
|
||||
NOTE this could have been a simple nn.Sequential but needed to wrap in module to use Interface
|
||||
def for ModuleDict torchscript compat.
|
||||
"""
|
||||
def __init__(self, depth, dim, num_heads, mlp_ratio=4., block=None, **kwargs):
|
||||
super().__init__()
|
||||
block = block or TransformerBlock
|
||||
self.stack = nn.Sequential(*[
|
||||
block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, **kwargs)
|
||||
for _ in range(depth)
|
||||
])
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.stack(x)
|
||||
|
||||
|
||||
def get_layer_layout(cross_depths, num_stages=8, share_weights=None):
|
||||
if isinstance(cross_depths, (tuple, list)):
|
||||
stage_cross_depths = tuple(cross_depths)
|
||||
stage_cross_depths = (stage_cross_depths + (0,) * num_stages)[:num_stages]
|
||||
else:
|
||||
stage_cross_depths = to_ntuple(num_stages)(cross_depths)
|
||||
prev_cross_key = ''
|
||||
prev_transformer_key = ''
|
||||
keys = []
|
||||
num_cross = 0
|
||||
num_transformer = 0
|
||||
for i, cd in enumerate(stage_cross_depths):
|
||||
for j in range(cd):
|
||||
key = prev_cross_key
|
||||
if share_weights is None or num_cross <= share_weights[0]:
|
||||
key = f'c{i}_{j}'
|
||||
keys += [key]
|
||||
prev_cross_key = key
|
||||
num_cross += 1
|
||||
key = prev_transformer_key
|
||||
if share_weights is None or num_transformer <= share_weights[1]:
|
||||
key = f't{i}'
|
||||
keys += [key]
|
||||
prev_transformer_key = key
|
||||
num_transformer += 1
|
||||
return keys
|
||||
|
||||
|
||||
class Perceiver(nn.Module):
|
||||
""" Perceiver
|
||||
|
||||
Paper: `Perceiver: General Perception with Iterative Attention` - https://arxiv.org/abs/2103.03206
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_chans=3, num_classes=1000, num_stages=8, cross_depths=(1,), transformer_depth=6,
|
||||
latent_dim=1024, num_latents=512, num_latent_heads=8, latent_mlp_ratio=1.0,
|
||||
cross_attn_dim=None, num_cross_heads=1, cross_mlp_ratio=1.0, share_weights=(1, 0),
|
||||
pos_embed_type='fourier', pos_embed_dim=128, data_bands=64, data_ndim=2, data_max_freq=10,
|
||||
data_spatial=False, qkv_bias=True, cross_block=None, transformer_block=None,
|
||||
cross_attn_layer=None, attn_layer=None, norm_layer=None, act_layer=None,
|
||||
drop_rate=0., drop_path_rate=0., weight_init=''):
|
||||
"""
|
||||
Args:
|
||||
in_chans (int): number of input channels
|
||||
num_classes (int): number of classes for classification head
|
||||
num_stages (int): number of stages (cross + transformer stack repeats)
|
||||
num_cross_heads (int): number of cross-attention heads
|
||||
cross_mlp_ratio (flaot): ratio of mlp hidden dim to embedding dim
|
||||
share_weights (Optiona[Tuple]): starting index of latent and transformer share (or None for no share)
|
||||
latent_dim (int):
|
||||
num_latents (int):
|
||||
num_latent_heads (int): number of latent-attention heads
|
||||
latent_mlp_ratio (float):
|
||||
qkv_bias (bool): enable bias for qkv if True
|
||||
pos_embed_type (str): type of pos embed (TODO: currently defaults to fourier)
|
||||
pos_embed_dim (int): embedding dimension (for other pos-embed options besides fourier)
|
||||
data_bands (int):
|
||||
data_ndim (int):
|
||||
data_max_freq (int):
|
||||
drop_rate (float): dropout rate
|
||||
drop_path_rate (float): stochastic depth rate
|
||||
norm_layer: (nn.Module): normalization layer
|
||||
weight_init: (str): weight init scheme
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.latent_dim = latent_dim
|
||||
cross_block = cross_block or CrossBlock
|
||||
transformer_block = transformer_block or TransformerBlock
|
||||
cross_attn_layer = cross_attn_layer or CrossAttention
|
||||
attn_layer = attn_layer or Attention
|
||||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
||||
act_layer = act_layer or nn.GELU
|
||||
|
||||
self.latents = nn.Parameter(torch.zeros(num_latents, latent_dim))
|
||||
self.data_bands = data_bands
|
||||
self.data_max_freq = data_max_freq
|
||||
self.data_ndim = data_ndim
|
||||
self.data_dim = self.data_ndim * (2 * self.data_bands + 1) + in_chans
|
||||
self.data_spatial = data_spatial
|
||||
|
||||
self.blocks_cross = nn.ModuleDict()
|
||||
self.blocks_trans = nn.ModuleDict()
|
||||
self.layer_keys = get_layer_layout(cross_depths, num_stages, share_weights)
|
||||
for i, k in enumerate(self.layer_keys):
|
||||
stage_args = dict(
|
||||
qkv_bias=qkv_bias, drop=drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer, act_layer=act_layer)
|
||||
if k.startswith('c'):
|
||||
self.blocks_cross[k] = cross_block(
|
||||
latent_dim=latent_dim, data_dim=self.data_dim, attn_dim=cross_attn_dim, num_heads=num_cross_heads,
|
||||
mlp_ratio=cross_mlp_ratio, attn_layer=cross_attn_layer, **stage_args)
|
||||
else:
|
||||
self.blocks_trans[k] = TransformerStack(
|
||||
depth=transformer_depth, dim=latent_dim, num_heads=num_latent_heads,
|
||||
mlp_ratio=latent_mlp_ratio, attn_layer=attn_layer, block=transformer_block, **stage_args)
|
||||
|
||||
self.norm = norm_layer(latent_dim)
|
||||
self.head = nn.Linear(latent_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.init_weights(weight_init)
|
||||
|
||||
def init_weights(self, mode=''):
|
||||
assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
|
||||
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
|
||||
trunc_normal_(self.latents, std=.02)
|
||||
named_apply(partial(_init_weights, head_bias=head_bias), self)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {'latents'}
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.latent_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x):
|
||||
B, C, H, W = x.shape
|
||||
# FIXME cache fourier embedding and implement positional options
|
||||
# FIXME support ndim inputs, don't assume 2D?
|
||||
data = fourier_grid(x.shape[2:], max_freq_log2=self.data_max_freq, num_bands=self.data_bands, device=x.device)
|
||||
if self.data_spatial:
|
||||
data = torch.cat([x, data.unsqueeze(0).expand(B, -1, -1, -1).permute(0, 3, 1, 2)], dim=1)
|
||||
else:
|
||||
data = torch.cat([x.permute(0, 2, 3, 1), data.unsqueeze(0).expand(B, -1, -1, -1)], dim=-1)
|
||||
data = data.reshape(B, H * W, -1)
|
||||
x = self.latents.unsqueeze(0).expand(B, -1, -1)
|
||||
for k in self.layer_keys:
|
||||
if k.startswith('c'):
|
||||
cross_blocks: CrossInterface = self.blocks_cross[k] # interface annotation for torchscript sillyness
|
||||
x = cross_blocks.forward(x, data)
|
||||
else:
|
||||
transformer: TransformerInterface = self.blocks_trans[k]
|
||||
x = transformer.forward(x)
|
||||
x = self.norm(x)
|
||||
x = x.mean(dim=1)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def _init_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
|
||||
""" weight initialization
|
||||
"""
|
||||
if isinstance(module, nn.Linear):
|
||||
if name.startswith('head'):
|
||||
nn.init.zeros_(module.weight)
|
||||
nn.init.constant_(module.bias, head_bias)
|
||||
else:
|
||||
nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
if 'mlp' in name:
|
||||
nn.init.normal_(module.bias, std=1e-6)
|
||||
else:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
lecun_normal_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
|
||||
nn.init.zeros_(module.bias)
|
||||
nn.init.ones_(module.weight)
|
||||
|
||||
|
||||
def _create_perceiver(variant, pretrained=False, default_cfg=None, **kwargs):
|
||||
default_cfg = default_cfg or default_cfgs[variant]
|
||||
if kwargs.get('features_only', None):
|
||||
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
||||
model = build_model_with_cfg(
|
||||
Perceiver, variant, pretrained,
|
||||
default_cfg=default_cfg,
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def perceiver_ss(pretrained=False, **kwargs):
|
||||
""" Perceiver-Small (Shared)
|
||||
One initial cross attn and all transformer stacks shared. ~11M params
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
cross_depths=(1,), latent_dim=512, num_latents=256, cross_attn_dim=128, data_bands=36, **kwargs)
|
||||
model = _create_perceiver('perceiver_ss', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def perceiver_s(pretrained=False, **kwargs):
|
||||
""" Perceiver-Small
|
||||
One initial cross attn and all but first transformer stacks shared. ~20M params
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
cross_depths=(1,), latent_dim=512, num_latents=256, cross_attn_dim=128, data_bands=36,
|
||||
share_weights=(1, 1), **kwargs)
|
||||
model = _create_perceiver('perceiver_s', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def perceiver_m(pretrained=False, **kwargs):
|
||||
""" Perceiver-Medium
|
||||
Two cross attn (one per each initial transformer stack), all transformers shared. ~25M params.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
cross_depths=(1,) * 2, latent_dim=768, num_latents=384, cross_attn_dim=160, data_bands=40, **kwargs)
|
||||
model = _create_perceiver('perceiver_m', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def perceiver_m_ls(pretrained=False, **kwargs):
|
||||
""" Perceiver-Medium w/ LayerScale + Affine
|
||||
Two cross attn (one per each initial transformer stack), all transformers shared. ~25M params.
|
||||
LayerScale + Affine influenced by CaiT, LeViT, ResMLP from Facebook AI
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
cross_depths=(1,) * 2, latent_dim=768, num_latents=384, cross_attn_dim=160, data_bands=40,
|
||||
transformer_block=TransformerBlockLayerScale, cross_block=CrossBlockLayerScale,
|
||||
norm_layer=Affine, **kwargs)
|
||||
model = _create_perceiver('perceiver_m_ls', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def perceiver_l(pretrained=False, **kwargs):
|
||||
""" Perceiver-Large
|
||||
One cross attn per 8 transformer stacks. All but first cross attn shared, all transformer stacks shared.
|
||||
This variant is closest to the paper model for reported ImageNet results. ~45M params.
|
||||
"""
|
||||
model_kwargs = dict(cross_depths=1, latent_dim=1024, num_latents=512, **kwargs)
|
||||
model = _create_perceiver('perceiver_l', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
Loading…
Reference in new issue