Compare commits
3 Commits
Author | SHA1 | Date |
---|---|---|
Ross Wightman | 55b135d5c6 | 3 years ago |
Ross Wightman | 820c262f33 | 3 years ago |
Ross Wightman | 77698d80a5 | 3 years ago |
@@ -0,0 +1,493 @@
""" Perceiver

Paper: `Perceiver: General Perception with Iterative Attention` - https://arxiv.org/abs/2103.03206

Official DeepMind code: TBD (doesn't exist yet)

Fourier feature position embedding references:
* Official NeRF impl - https://github.com/bmild/nerf
* Lucidrain's Perceiver impl - https://github.com/lucidrains/perceiver-pytorch

Status:
* Work in progress, currently running training trials with S and M models (rather slow)

Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from functools import partial
from typing import List, Tuple

import torch
import torch.nn as nn

from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply
from .layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_ntuple
from .registry import register_model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': None, 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    # experimental Perceiver models (no pretrained weights available yet)
    'perceiver_ss': _cfg(
        url='', input_size=(3, 192, 192)),
    'perceiver_s': _cfg(
        url='', input_size=(3, 192, 192)),
    'perceiver_m': _cfg(
        url=''),
    'perceiver_m_ls': _cfg(
        url=''),
    'perceiver_l': _cfg(
        url=''),
}


def fourier_encode(x, max_freq_log2: int = 8, num_bands: int = 64):
    """ Fourier feature embedding.
    Referenced official NeRF code and Lucidrain's PyTorch Perceiver impl.
    """
    # FIXME this will likely need to change once official code / weights are available
    x = x.unsqueeze(-1)
    bands = 2 ** torch.linspace(0, max_freq_log2 - 1, num_bands, device=x.device, dtype=x.dtype)
    x_bands = x * math.pi * bands
    x = torch.cat([x, x_bands.sin(), x_bands.cos()], dim=-1)
    return x
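
# Shape sketch (illustrative): for a coordinate tensor of shape (H, W, 2) and num_bands=64,
# the output has shape (H, W, 2, 2 * 64 + 1) = (H, W, 2, 129), i.e. the raw coordinate plus
# 64 sine and 64 cosine features per spatial dimension.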


def fourier_grid(
        shape: List[int], max_freq_log2: int = 8, num_bands: int = 64, device: torch.device = torch.device('cuda')):
    grid = torch.stack(torch.meshgrid([torch.linspace(-1., 1., steps=s, device=device) for s in shape]), dim=-1)
    enc_pos = fourier_encode(grid, max_freq_log2, num_bands)
    return enc_pos.transpose(-1, -2).flatten(len(shape))
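
# Illustrative shapes, assuming a 2D image grid: fourier_grid([H, W], num_bands=64) returns a
# tensor of shape (H, W, 2 * (2 * 64 + 1)) = (H, W, 258), matching the positional part of
# Perceiver.data_dim below (data_ndim * (2 * data_bands + 1)).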


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CrossAttention(nn.Module):
    """ Cross-attention from latent queries to data keys/values, computed at attn_dim and
    projected back to latent_dim.
    """
    def __init__(self, latent_dim, data_dim, attn_dim=None, num_heads=1, qkv_bias=True, proj_drop=0.):
        super().__init__()
        assert latent_dim % num_heads == 0, f"dim {latent_dim} should be divisible by num_heads {num_heads}."

        self.latent_dim = latent_dim
        self.attn_dim = attn_dim or min(latent_dim, data_dim)
        self.num_heads = num_heads
        head_dim = self.attn_dim // num_heads
        self.scale = head_dim ** -0.5

        self.q = nn.Linear(latent_dim, self.attn_dim, bias=qkv_bias)
        self.kv = nn.Linear(data_dim, self.attn_dim * 2, bias=qkv_bias)
        self.proj = nn.Linear(self.attn_dim, latent_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, latent, data):
        B = latent.shape[0]
        q = self.q(latent).reshape(B, -1, self.num_heads, self.attn_dim // self.num_heads).permute(0, 2, 1, 3)

        kv = self.kv(data).reshape(B, -1, 2, self.num_heads, self.attn_dim // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, -1, self.attn_dim)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out
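
# Shape sketch: with latent (B, N_latent, latent_dim) and data (B, N_data, data_dim), queries
# become (B, num_heads, N_latent, attn_dim // num_heads), keys/values (B, num_heads, N_data,
# attn_dim // num_heads), and the attended output is projected back to (B, N_latent, latent_dim).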


class Affine(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones((1, 1, dim)))
        self.beta = nn.Parameter(torch.zeros((1, 1, dim)))

    def forward(self, x):
        return torch.addcmul(self.beta, self.alpha, x)


@torch.jit.interface
class CrossInterface(torch.nn.Module):
    def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
        pass


class CrossBlock(nn.Module):

    def __init__(self, latent_dim, data_dim, num_heads, attn_dim=None, mlp_ratio=4., qkv_bias=True,
                 drop=0., drop_path=0., attn_layer=CrossAttention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1_latent = norm_layer(latent_dim)
        self.norm1_data = norm_layer(data_dim)
        self.attn = attn_layer(
            latent_dim, data_dim, num_heads=num_heads, attn_dim=attn_dim, qkv_bias=qkv_bias, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(latent_dim)
        mlp_hidden_dim = int(latent_dim * mlp_ratio)
        self.mlp = Mlp(in_features=latent_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
        latent = latent + self.drop_path(self.attn(
            self.norm1_latent(latent),
            self.norm1_data(data),
        ))
        latent = latent + self.drop_path(self.mlp(self.norm2(latent)))
        return latent


class CrossBlockLayerScale(nn.Module):

    def __init__(self, latent_dim, data_dim, num_heads, attn_dim=None, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
                 drop=0., drop_path=0., attn_layer=CrossAttention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1_latent = norm_layer(latent_dim)
        self.norm1_data = norm_layer(data_dim)
        self.attn = attn_layer(
            latent_dim, data_dim, num_heads=num_heads, attn_dim=attn_dim, qkv_bias=qkv_bias, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(latent_dim)
        mlp_hidden_dim = int(latent_dim * mlp_ratio)
        self.mlp = Mlp(in_features=latent_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.ls1 = nn.Parameter(init_values * torch.ones(latent_dim))
        self.ls2 = nn.Parameter(init_values * torch.ones(latent_dim))

    def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
        latent = latent + self.drop_path(self.ls1 * self.attn(
            self.norm1_latent(latent),
            self.norm1_data(data),
        ))
        latent = latent + self.drop_path(self.ls2 * self.mlp(self.norm2(latent)))
        return latent


@torch.jit.interface
class TransformerInterface(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


class TransformerBlock(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
                 drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class TransformerBlockLayerScale(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
                 drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.ls1 = nn.Parameter(init_values * torch.ones(dim))
        self.ls2 = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        x = x + self.drop_path(self.ls1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.ls2 * self.mlp(self.norm2(x)))
        return x


class TransformerStack(nn.Module):
    """ A stack-o-transformers.

    NOTE: this could have been a plain nn.Sequential, but it is wrapped in a module so the
    interface def above can be used for ModuleDict torchscript compat.
    """
    def __init__(self, depth, dim, num_heads, mlp_ratio=4., block=None, **kwargs):
        super().__init__()
        block = block or TransformerBlock
        self.stack = nn.Sequential(*[
            block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, **kwargs)
            for _ in range(depth)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.stack(x)


def get_layer_layout(cross_depths, num_stages=8, share_weights=None):
    if isinstance(cross_depths, (tuple, list)):
        stage_cross_depths = tuple(cross_depths)
        stage_cross_depths = (stage_cross_depths + (0,) * num_stages)[:num_stages]
    else:
        stage_cross_depths = to_ntuple(num_stages)(cross_depths)
    prev_cross_key = ''
    prev_transformer_key = ''
    keys = []
    num_cross = 0
    num_transformer = 0
    for i, cd in enumerate(stage_cross_depths):
        for j in range(cd):
            key = prev_cross_key
            if share_weights is None or num_cross <= share_weights[0]:
                key = f'c{i}_{j}'
            keys += [key]
            prev_cross_key = key
            num_cross += 1
        key = prev_transformer_key
        if share_weights is None or num_transformer <= share_weights[1]:
            key = f't{i}'
        keys += [key]
        prev_transformer_key = key
        num_transformer += 1
    return keys
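
# Worked examples of the layout keys above (repeated keys map to the same ModuleDict entry,
# i.e. shared weights); num_stages=4 used here for brevity:
#   get_layer_layout(1, num_stages=4, share_weights=(1, 0))
#     -> ['c0_0', 't0', 'c1_0', 't0', 'c1_0', 't0', 'c1_0', 't0']
#   get_layer_layout((1,), num_stages=4, share_weights=(1, 0))
#     -> ['c0_0', 't0', 't0', 't0', 't0']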


class Perceiver(nn.Module):
    """ Perceiver

    Paper: `Perceiver: General Perception with Iterative Attention` - https://arxiv.org/abs/2103.03206
    """

    def __init__(
            self, in_chans=3, num_classes=1000, num_stages=8, cross_depths=(1,), transformer_depth=6,
            latent_dim=1024, num_latents=512, num_latent_heads=8, latent_mlp_ratio=1.0,
            cross_attn_dim=None, num_cross_heads=1, cross_mlp_ratio=1.0, share_weights=(1, 0),
            pos_embed_type='fourier', pos_embed_dim=128, data_bands=64, data_ndim=2, data_max_freq=10,
            data_spatial=False, qkv_bias=True, cross_block=None, transformer_block=None,
            cross_attn_layer=None, attn_layer=None, norm_layer=None, act_layer=None,
            drop_rate=0., drop_path_rate=0., weight_init=''):
        """
        Args:
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            num_stages (int): number of stages (cross + transformer stack repeats)
            num_cross_heads (int): number of cross-attention heads
            cross_mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            share_weights (Optional[Tuple]): starting index of latent and transformer share (or None for no share)
            latent_dim (int): latent (and final feature) embedding dim
            num_latents (int): number of latent vectors
            num_latent_heads (int): number of latent-attention heads
            latent_mlp_ratio (float): ratio of mlp hidden dim to latent dim in latent transformer blocks
            qkv_bias (bool): enable bias for qkv if True
            pos_embed_type (str): type of pos embed (TODO: currently defaults to fourier)
            pos_embed_dim (int): embedding dimension (for other pos-embed options besides fourier)
            data_bands (int): number of fourier frequency bands per spatial dim
            data_ndim (int): number of spatial dims of the input data (2 for images)
            data_max_freq (int): log2 max frequency for the fourier feature bands
            drop_rate (float): dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
            weight_init: (str): weight init scheme
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.latent_dim = latent_dim
        cross_block = cross_block or CrossBlock
        transformer_block = transformer_block or TransformerBlock
        cross_attn_layer = cross_attn_layer or CrossAttention
        attn_layer = attn_layer or Attention
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.latents = nn.Parameter(torch.zeros(num_latents, latent_dim))
        self.data_bands = data_bands
        self.data_max_freq = data_max_freq
        self.data_ndim = data_ndim
        self.data_dim = self.data_ndim * (2 * self.data_bands + 1) + in_chans
        self.data_spatial = data_spatial

        self.blocks_cross = nn.ModuleDict()
        self.blocks_trans = nn.ModuleDict()
        self.layer_keys = get_layer_layout(cross_depths, num_stages, share_weights)
        for i, k in enumerate(self.layer_keys):
            stage_args = dict(
                qkv_bias=qkv_bias, drop=drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer, act_layer=act_layer)
            if k.startswith('c'):
                self.blocks_cross[k] = cross_block(
                    latent_dim=latent_dim, data_dim=self.data_dim, attn_dim=cross_attn_dim, num_heads=num_cross_heads,
                    mlp_ratio=cross_mlp_ratio, attn_layer=cross_attn_layer, **stage_args)
            else:
                self.blocks_trans[k] = TransformerStack(
                    depth=transformer_depth, dim=latent_dim, num_heads=num_latent_heads,
                    mlp_ratio=latent_mlp_ratio, attn_layer=attn_layer, block=transformer_block, **stage_args)

        self.norm = norm_layer(latent_dim)
        self.head = nn.Linear(latent_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.latents, std=.02)
        named_apply(partial(_init_weights, head_bias=head_bias), self)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'latents'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.latent_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B, C, H, W = x.shape
        # FIXME cache fourier embedding and implement positional options
        # FIXME support ndim inputs, don't assume 2D?
        data = fourier_grid(x.shape[2:], max_freq_log2=self.data_max_freq, num_bands=self.data_bands, device=x.device)
        if self.data_spatial:
            data = torch.cat([x, data.unsqueeze(0).expand(B, -1, -1, -1).permute(0, 3, 1, 2)], dim=1)
        else:
            data = torch.cat([x.permute(0, 2, 3, 1), data.unsqueeze(0).expand(B, -1, -1, -1)], dim=-1)
        data = data.reshape(B, H * W, -1)
        x = self.latents.unsqueeze(0).expand(B, -1, -1)
        for k in self.layer_keys:
            if k.startswith('c'):
                cross_blocks: CrossInterface = self.blocks_cross[k]  # interface annotation for torchscript silliness
                x = cross_blocks.forward(x, data)
            else:
                transformer: TransformerInterface = self.blocks_trans[k]
                x = transformer.forward(x)
        x = self.norm(x)
        x = x.mean(dim=1)
        return x
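
    # Data layout sketch for forward_features above (2D input, default data_spatial=False):
    #   x: (B, C, H, W) -> data: (B, H * W, C + 2 * (2 * data_bands + 1))
    #   latents: (num_latents, latent_dim) broadcast to (B, num_latents, latent_dim)
    # Each 'c*' key cross-attends the latents to the data, each 't*' key applies a (possibly
    # weight-shared) latent transformer stack, and the final features are the mean over latents.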

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


def _init_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
    """ weight initialization
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        else:
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                if 'mlp' in name:
                    nn.init.normal_(module.bias, std=1e-6)
                else:
                    nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)


def _create_perceiver(variant, pretrained=False, default_cfg=None, **kwargs):
    default_cfg = default_cfg or default_cfgs[variant]
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Perceiver models.')
    model = build_model_with_cfg(
        Perceiver, variant, pretrained,
        default_cfg=default_cfg,
        **kwargs)
    return model


@register_model
def perceiver_ss(pretrained=False, **kwargs):
    """ Perceiver-Small (Shared)
    One initial cross attn and all transformer stacks shared. ~11M params
    """
    model_kwargs = dict(
        cross_depths=(1,), latent_dim=512, num_latents=256, cross_attn_dim=128, data_bands=36, **kwargs)
    model = _create_perceiver('perceiver_ss', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def perceiver_s(pretrained=False, **kwargs):
    """ Perceiver-Small
    One initial cross attn and all but the first transformer stack shared. ~20M params
    """
    model_kwargs = dict(
        cross_depths=(1,), latent_dim=512, num_latents=256, cross_attn_dim=128, data_bands=36,
        share_weights=(1, 1), **kwargs)
    model = _create_perceiver('perceiver_s', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def perceiver_m(pretrained=False, **kwargs):
    """ Perceiver-Medium
    Two cross attn (one before each of the first two transformer stacks), all transformers shared. ~25M params.
    """
    model_kwargs = dict(
        cross_depths=(1,) * 2, latent_dim=768, num_latents=384, cross_attn_dim=160, data_bands=40, **kwargs)
    model = _create_perceiver('perceiver_m', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def perceiver_m_ls(pretrained=False, **kwargs):
    """ Perceiver-Medium w/ LayerScale + Affine
    Two cross attn (one before each of the first two transformer stacks), all transformers shared. ~25M params.
    LayerScale + Affine influenced by CaiT, LeViT, ResMLP from Facebook AI
    """
    model_kwargs = dict(
        cross_depths=(1,) * 2, latent_dim=768, num_latents=384, cross_attn_dim=160, data_bands=40,
        transformer_block=TransformerBlockLayerScale, cross_block=CrossBlockLayerScale,
        norm_layer=Affine, **kwargs)
    model = _create_perceiver('perceiver_m_ls', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def perceiver_l(pretrained=False, **kwargs):
    """ Perceiver-Large
    One cross attn per transformer stack (8 stages). All but the first cross attn shared, all transformer stacks shared.
    This variant is closest to the paper model for reported ImageNet results. ~45M params.
    """
    model_kwargs = dict(cross_depths=1, latent_dim=1024, num_latents=512, **kwargs)
    model = _create_perceiver('perceiver_l', pretrained=pretrained, **model_kwargs)
    return model
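
# Usage sketch (assumes this file is importable as a timm model module so that the
# @register_model entries above are picked up by the registry):
#
#   import timm, torch
#   model = timm.create_model('perceiver_s', pretrained=False)
#   logits = model(torch.randn(1, 3, 192, 192))  # -> (1, 1000)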