Compare commits

...

3 Commits

Author SHA1 Message Date
Ross Wightman 55b135d5c6 Add layer scale + affine option to perceiver
3 years ago
Ross Wightman 820c262f33 Don't exclude perceivers from tests (yet)
3 years ago
Ross Wightman 77698d80a5 Initial Perceiver impl. WIP
3 years ago

@ -17,7 +17,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*']
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'perceiver*']
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures
@ -26,7 +26,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*']
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'perceiver_l*']
else:
EXCLUDE_FILTERS = []
@ -218,6 +218,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
# check first conv(s) names match default_cfg
first_conv = cfg['first_conv']
if first_conv is not None:
if isinstance(first_conv, str):
first_conv = (first_conv,)
assert isinstance(first_conv, (tuple, list))

@ -22,6 +22,7 @@ from .mobilenetv3 import *
from .nasnet import *
from .nest import *
from .nfnet import *
from .perceiver import *
from .pit import *
from .pnasnet import *
from .regnet import *
@ -36,6 +37,7 @@ from .sknet import *
from .swin_transformer import *
from .tnt import *
from .tresnet import *
from .twins import *
from .vgg import *
from .visformer import *
from .vision_transformer import *
@ -44,7 +46,6 @@ from .vovnet import *
from .xception import *
from .xception_aligned import *
from .xcit import *
from .twins import *
from .factory import create_model, split_model_name, safe_model_name
from .helpers import load_checkpoint, resume_checkpoint, model_parameters

@ -8,12 +8,6 @@ Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
Status:
* These models are a work in progress, experiments ongoing.
* Pretrained weights for two models so far, more to come.
* Model details updated to closer match official JAX code now that it's released
* NF-ResNet, NF-RegNet-B, and NFNet-F models supported
Hacked together by / copyright Ross Wightman, 2021.
"""
import math

@ -0,0 +1,493 @@
""" Perceiver
Paper: `Perceiver: General Perception with Iterative Attention` - https://arxiv.org/abs/2103.03206
Official Deepmind code: TBD (doesn't exist yet)
Fourier feature position embedding references:
* Official NeRF impl - https://github.com/bmild/nerf
* Lucidrain's Perceiver impl - https://github.com/lucidrains/perceiver-pytorch
Status:
* Work in progress, currently running training trials with S and M models (rather slow)
Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from functools import partial
from typing import List, Tuple
import torch
import torch.nn as nn
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply
from .layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_ntuple
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': None, 'classifier': 'head',
**kwargs
}
default_cfgs = {
# patch models (weights from official Google JAX impl)
'perceiver_ss': _cfg(
url='', input_size=(3, 192, 192)),
'perceiver_s': _cfg(
url='', input_size=(3, 192, 192)),
'perceiver_m': _cfg(
url=''),
'perceiver_m_ls': _cfg(
url=''),
'perceiver_l': _cfg(
url=''),
}
def fourier_encode(x, max_freq_log2: int = 8, num_bands: int = 64):
""" Fourier feature embedding.
Referenced official NeRF code and Lucidrain's PyTorch Perceiver impl.
"""
# FIXME this will likely need to change once official code / weights are available
x = x.unsqueeze(-1)
bands = 2 ** torch.linspace(0, max_freq_log2 - 1, num_bands, device=x.device, dtype=x.dtype)
x_bands = x * math.pi * bands
x = torch.cat([x, x_bands.sin(), x_bands.cos()], dim=-1)
return x
def fourier_grid(
shape: List[int], max_freq_log2: int = 8, num_bands: int = 64, device: torch.device = torch.device('cuda')):
grid = torch.stack(torch.meshgrid([torch.linspace(-1., 1., steps=s, device=device) for s in shape]), dim=-1)
enc_pos = fourier_encode(grid, max_freq_log2, num_bands)
return enc_pos.transpose(-1, -2).flatten(len(shape))
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class CrossAttention(nn.Module):
"""
"""
def __init__(self, latent_dim, data_dim, attn_dim=None, num_heads=1, qkv_bias=True, proj_drop=0.):
super().__init__()
assert latent_dim % num_heads == 0, f"dim {latent_dim} should be divided by num_heads {num_heads}."
self.latent_dim = latent_dim
self.attn_dim = attn_dim or min(latent_dim, data_dim)
self.num_heads = num_heads
head_dim = self.attn_dim // num_heads
self.scale = head_dim ** -0.5
self.q = nn.Linear(latent_dim, self.attn_dim, bias=qkv_bias)
self.kv = nn.Linear(data_dim, self.attn_dim * 2, bias=qkv_bias)
self.proj = nn.Linear(self.attn_dim, latent_dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, latent, data):
B = latent.shape[0]
q = self.q(latent).reshape(B, -1, self.num_heads, self.attn_dim // self.num_heads).permute(0, 2, 1, 3)
kv = self.kv(data).reshape(B, -1, 2, self.num_heads, self.attn_dim // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, -1, self.attn_dim)
out = self.proj(out)
out = self.proj_drop(out)
return out
class Affine(nn.Module):
def __init__(self, dim):
super().__init__()
self.alpha = nn.Parameter(torch.ones((1, 1, dim)))
self.beta = nn.Parameter(torch.zeros((1, 1, dim)))
def forward(self, x):
return torch.addcmul(self.beta, self.alpha, x)
@torch.jit.interface
class CrossInterface(torch.nn.Module):
def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
pass
class CrossBlock(nn.Module):
def __init__(self, latent_dim, data_dim, num_heads, attn_dim=None, mlp_ratio=4., qkv_bias=True,
drop=0., drop_path=0., attn_layer=CrossAttention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1_latent = norm_layer(latent_dim)
self.norm1_data = norm_layer(data_dim)
self.attn = attn_layer(
latent_dim, data_dim, num_heads=num_heads, attn_dim=attn_dim, qkv_bias=qkv_bias, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(latent_dim)
mlp_hidden_dim = int(latent_dim * mlp_ratio)
self.mlp = Mlp(in_features=latent_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
latent = latent + self.drop_path(self.attn(
self.norm1_latent(latent),
self.norm1_data(data),
))
latent = latent + self.drop_path(self.mlp(self.norm2(latent)))
return latent
class CrossBlockLayerScale(nn.Module):
def __init__(self, latent_dim, data_dim, num_heads, attn_dim=None, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
drop=0., drop_path=0., attn_layer=CrossAttention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1_latent = norm_layer(latent_dim)
self.norm1_data = norm_layer(data_dim)
self.attn = attn_layer(
latent_dim, data_dim, num_heads=num_heads, attn_dim=attn_dim, qkv_bias=qkv_bias, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(latent_dim)
mlp_hidden_dim = int(latent_dim * mlp_ratio)
self.mlp = Mlp(in_features=latent_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.ls1 = nn.Parameter(init_values * torch.ones(latent_dim))
self.ls2 = nn.Parameter(init_values * torch.ones(latent_dim))
def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
latent = latent + self.drop_path(self.ls1 * self.attn(
self.norm1_latent(latent),
self.norm1_data(data),
))
latent = latent + self.drop_path(self.ls2 * self.mlp(self.norm2(latent)))
return latent
@torch.jit.interface
class TransformerInterface(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
pass
class TransformerBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class TransformerBlockLayerScale(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.ls1 = nn.Parameter(init_values * torch.ones(dim))
self.ls2 = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
x = x + self.drop_path(self.ls1 * self.attn(self.norm1(x)))
x = x + self.drop_path(self.ls2 * self.mlp(self.norm2(x)))
return x
class TransformerStack(nn.Module):
""" A stack-o-transformers
NOTE this could have been a simple nn.Sequential but needed to wrap in module to use Interface
def for ModuleDict torchscript compat.
"""
def __init__(self, depth, dim, num_heads, mlp_ratio=4., block=None, **kwargs):
super().__init__()
block = block or TransformerBlock
self.stack = nn.Sequential(*[
block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, **kwargs)
for _ in range(depth)
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.stack(x)
def get_layer_layout(cross_depths, num_stages=8, share_weights=None):
if isinstance(cross_depths, (tuple, list)):
stage_cross_depths = tuple(cross_depths)
stage_cross_depths = (stage_cross_depths + (0,) * num_stages)[:num_stages]
else:
stage_cross_depths = to_ntuple(num_stages)(cross_depths)
prev_cross_key = ''
prev_transformer_key = ''
keys = []
num_cross = 0
num_transformer = 0
for i, cd in enumerate(stage_cross_depths):
for j in range(cd):
key = prev_cross_key
if share_weights is None or num_cross <= share_weights[0]:
key = f'c{i}_{j}'
keys += [key]
prev_cross_key = key
num_cross += 1
key = prev_transformer_key
if share_weights is None or num_transformer <= share_weights[1]:
key = f't{i}'
keys += [key]
prev_transformer_key = key
num_transformer += 1
return keys
class Perceiver(nn.Module):
""" Perceiver
Paper: `Perceiver: General Perception with Iterative Attention` - https://arxiv.org/abs/2103.03206
"""
def __init__(
self, in_chans=3, num_classes=1000, num_stages=8, cross_depths=(1,), transformer_depth=6,
latent_dim=1024, num_latents=512, num_latent_heads=8, latent_mlp_ratio=1.0,
cross_attn_dim=None, num_cross_heads=1, cross_mlp_ratio=1.0, share_weights=(1, 0),
pos_embed_type='fourier', pos_embed_dim=128, data_bands=64, data_ndim=2, data_max_freq=10,
data_spatial=False, qkv_bias=True, cross_block=None, transformer_block=None,
cross_attn_layer=None, attn_layer=None, norm_layer=None, act_layer=None,
drop_rate=0., drop_path_rate=0., weight_init=''):
"""
Args:
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
num_stages (int): number of stages (cross + transformer stack repeats)
num_cross_heads (int): number of cross-attention heads
cross_mlp_ratio (flaot): ratio of mlp hidden dim to embedding dim
share_weights (Optiona[Tuple]): starting index of latent and transformer share (or None for no share)
latent_dim (int):
num_latents (int):
num_latent_heads (int): number of latent-attention heads
latent_mlp_ratio (float):
qkv_bias (bool): enable bias for qkv if True
pos_embed_type (str): type of pos embed (TODO: currently defaults to fourier)
pos_embed_dim (int): embedding dimension (for other pos-embed options besides fourier)
data_bands (int):
data_ndim (int):
data_max_freq (int):
drop_rate (float): dropout rate
drop_path_rate (float): stochastic depth rate
norm_layer: (nn.Module): normalization layer
weight_init: (str): weight init scheme
"""
super().__init__()
self.num_classes = num_classes
self.num_features = self.latent_dim = latent_dim
cross_block = cross_block or CrossBlock
transformer_block = transformer_block or TransformerBlock
cross_attn_layer = cross_attn_layer or CrossAttention
attn_layer = attn_layer or Attention
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.latents = nn.Parameter(torch.zeros(num_latents, latent_dim))
self.data_bands = data_bands
self.data_max_freq = data_max_freq
self.data_ndim = data_ndim
self.data_dim = self.data_ndim * (2 * self.data_bands + 1) + in_chans
self.data_spatial = data_spatial
self.blocks_cross = nn.ModuleDict()
self.blocks_trans = nn.ModuleDict()
self.layer_keys = get_layer_layout(cross_depths, num_stages, share_weights)
for i, k in enumerate(self.layer_keys):
stage_args = dict(
qkv_bias=qkv_bias, drop=drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer, act_layer=act_layer)
if k.startswith('c'):
self.blocks_cross[k] = cross_block(
latent_dim=latent_dim, data_dim=self.data_dim, attn_dim=cross_attn_dim, num_heads=num_cross_heads,
mlp_ratio=cross_mlp_ratio, attn_layer=cross_attn_layer, **stage_args)
else:
self.blocks_trans[k] = TransformerStack(
depth=transformer_depth, dim=latent_dim, num_heads=num_latent_heads,
mlp_ratio=latent_mlp_ratio, attn_layer=attn_layer, block=transformer_block, **stage_args)
self.norm = norm_layer(latent_dim)
self.head = nn.Linear(latent_dim, num_classes) if num_classes > 0 else nn.Identity()
self.init_weights(weight_init)
def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.latents, std=.02)
named_apply(partial(_init_weights, head_bias=head_bias), self)
@torch.jit.ignore
def no_weight_decay(self):
return {'latents'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.latent_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B, C, H, W = x.shape
# FIXME cache fourier embedding and implement positional options
# FIXME support ndim inputs, don't assume 2D?
data = fourier_grid(x.shape[2:], max_freq_log2=self.data_max_freq, num_bands=self.data_bands, device=x.device)
if self.data_spatial:
data = torch.cat([x, data.unsqueeze(0).expand(B, -1, -1, -1).permute(0, 3, 1, 2)], dim=1)
else:
data = torch.cat([x.permute(0, 2, 3, 1), data.unsqueeze(0).expand(B, -1, -1, -1)], dim=-1)
data = data.reshape(B, H * W, -1)
x = self.latents.unsqueeze(0).expand(B, -1, -1)
for k in self.layer_keys:
if k.startswith('c'):
cross_blocks: CrossInterface = self.blocks_cross[k] # interface annotation for torchscript sillyness
x = cross_blocks.forward(x, data)
else:
transformer: TransformerInterface = self.blocks_trans[k]
x = transformer.forward(x)
x = self.norm(x)
x = x.mean(dim=1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _init_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
""" weight initialization
"""
if isinstance(module, nn.Linear):
if name.startswith('head'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
if 'mlp' in name:
nn.init.normal_(module.bias, std=1e-6)
else:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
nn.init.zeros_(module.bias)
nn.init.ones_(module.weight)
def _create_perceiver(variant, pretrained=False, default_cfg=None, **kwargs):
default_cfg = default_cfg or default_cfgs[variant]
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
Perceiver, variant, pretrained,
default_cfg=default_cfg,
**kwargs)
return model
@register_model
def perceiver_ss(pretrained=False, **kwargs):
""" Perceiver-Small (Shared)
One initial cross attn and all transformer stacks shared. ~11M params
"""
model_kwargs = dict(
cross_depths=(1,), latent_dim=512, num_latents=256, cross_attn_dim=128, data_bands=36, **kwargs)
model = _create_perceiver('perceiver_ss', pretrained=pretrained, **model_kwargs)
return model
@register_model
def perceiver_s(pretrained=False, **kwargs):
""" Perceiver-Small
One initial cross attn and all but first transformer stacks shared. ~20M params
"""
model_kwargs = dict(
cross_depths=(1,), latent_dim=512, num_latents=256, cross_attn_dim=128, data_bands=36,
share_weights=(1, 1), **kwargs)
model = _create_perceiver('perceiver_s', pretrained=pretrained, **model_kwargs)
return model
@register_model
def perceiver_m(pretrained=False, **kwargs):
""" Perceiver-Medium
Two cross attn (one per each initial transformer stack), all transformers shared. ~25M params.
"""
model_kwargs = dict(
cross_depths=(1,) * 2, latent_dim=768, num_latents=384, cross_attn_dim=160, data_bands=40, **kwargs)
model = _create_perceiver('perceiver_m', pretrained=pretrained, **model_kwargs)
return model
@register_model
def perceiver_m_ls(pretrained=False, **kwargs):
""" Perceiver-Medium w/ LayerScale + Affine
Two cross attn (one per each initial transformer stack), all transformers shared. ~25M params.
LayerScale + Affine influenced by CaiT, LeViT, ResMLP from Facebook AI
"""
model_kwargs = dict(
cross_depths=(1,) * 2, latent_dim=768, num_latents=384, cross_attn_dim=160, data_bands=40,
transformer_block=TransformerBlockLayerScale, cross_block=CrossBlockLayerScale,
norm_layer=Affine, **kwargs)
model = _create_perceiver('perceiver_m_ls', pretrained=pretrained, **model_kwargs)
return model
@register_model
def perceiver_l(pretrained=False, **kwargs):
""" Perceiver-Large
One cross attn per 8 transformer stacks. All but first cross attn shared, all transformer stacks shared.
This variant is closest to the paper model for reported ImageNet results. ~45M params.
"""
model_kwargs = dict(cross_depths=1, latent_dim=1024, num_latents=512, **kwargs)
model = _create_perceiver('perceiver_l', pretrained=pretrained, **model_kwargs)
return model
Loading…
Cancel
Save