|
|
@ -14,17 +14,15 @@ Status:
|
|
|
|
Hacked together by / copyright Ross Wightman, 2021.
|
|
|
|
Hacked together by / copyright Ross Wightman, 2021.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
import math
|
|
|
|
import math
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
from functools import partial
|
|
|
|
from functools import partial
|
|
|
|
from typing import List, Tuple
|
|
|
|
from typing import List, Tuple
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
|
|
|
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
|
|
|
from .helpers import build_model_with_cfg, named_apply
|
|
|
|
from .helpers import build_model_with_cfg, named_apply
|
|
|
|
from .layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_ntuple, ConvBnAct, LayerNorm2d
|
|
|
|
from .layers import Mlp, DropPath, trunc_normal_, lecun_normal_, to_ntuple
|
|
|
|
from .registry import register_model
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -47,6 +45,8 @@ default_cfgs = {
|
|
|
|
url='', input_size=(3, 192, 192)),
|
|
|
|
url='', input_size=(3, 192, 192)),
|
|
|
|
'perceiver_m': _cfg(
|
|
|
|
'perceiver_m': _cfg(
|
|
|
|
url=''),
|
|
|
|
url=''),
|
|
|
|
|
|
|
|
'perceiver_m_ls': _cfg(
|
|
|
|
|
|
|
|
url=''),
|
|
|
|
'perceiver_l': _cfg(
|
|
|
|
'perceiver_l': _cfg(
|
|
|
|
url=''),
|
|
|
|
url=''),
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -130,6 +130,16 @@ class CrossAttention(nn.Module):
|
|
|
|
return out
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Affine(nn.Module):
|
|
|
|
|
|
|
|
def __init__(self, dim):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
self.alpha = nn.Parameter(torch.ones((1, 1, dim)))
|
|
|
|
|
|
|
|
self.beta = nn.Parameter(torch.zeros((1, 1, dim)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
return torch.addcmul(self.beta, self.alpha, x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.jit.interface
|
|
|
|
@torch.jit.interface
|
|
|
|
class CrossInterface(torch.nn.Module):
|
|
|
|
class CrossInterface(torch.nn.Module):
|
|
|
|
def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
|
|
|
|
def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
|
|
|
@ -159,6 +169,31 @@ class CrossBlock(nn.Module):
|
|
|
|
return latent
|
|
|
|
return latent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CrossBlockLayerScale(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, latent_dim, data_dim, num_heads, attn_dim=None, mlp_ratio=4., qkv_bias=True, init_values=1e-5,
|
|
|
|
|
|
|
|
drop=0., drop_path=0., attn_layer=CrossAttention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
self.norm1_latent = norm_layer(latent_dim)
|
|
|
|
|
|
|
|
self.norm1_data = norm_layer(data_dim)
|
|
|
|
|
|
|
|
self.attn = attn_layer(
|
|
|
|
|
|
|
|
latent_dim, data_dim, num_heads=num_heads, attn_dim=attn_dim, qkv_bias=qkv_bias, proj_drop=drop)
|
|
|
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
|
|
|
self.norm2 = norm_layer(latent_dim)
|
|
|
|
|
|
|
|
mlp_hidden_dim = int(latent_dim * mlp_ratio)
|
|
|
|
|
|
|
|
self.mlp = Mlp(in_features=latent_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
|
|
|
|
|
|
|
self.ls1 = nn.Parameter(init_values * torch.ones(latent_dim))
|
|
|
|
|
|
|
|
self.ls2 = nn.Parameter(init_values * torch.ones(latent_dim))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, latent: torch.Tensor, data: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
|
|
|
latent = latent + self.drop_path(self.ls1 * self.attn(
|
|
|
|
|
|
|
|
self.norm1_latent(latent),
|
|
|
|
|
|
|
|
self.norm1_data(data),
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
latent = latent + self.drop_path(self.ls2 * self.mlp(self.norm2(latent)))
|
|
|
|
|
|
|
|
return latent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.jit.interface
|
|
|
|
@torch.jit.interface
|
|
|
|
class TransformerInterface(torch.nn.Module):
|
|
|
|
class TransformerInterface(torch.nn.Module):
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
@ -167,7 +202,7 @@ class TransformerInterface(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
|
|
|
|
class TransformerBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
|
|
|
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
|
|
|
|
drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
|
|
|
drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
@ -182,15 +217,37 @@ class TransformerBlock(nn.Module):
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerBlockLayerScale(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, init_values=1e-5, drop=0.,
|
|
|
|
|
|
|
|
drop_path=0., attn_layer=Attention, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
self.norm1 = norm_layer(dim)
|
|
|
|
|
|
|
|
self.attn = attn_layer(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_drop=drop)
|
|
|
|
|
|
|
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
|
|
|
|
|
|
|
self.norm2 = norm_layer(dim)
|
|
|
|
|
|
|
|
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
|
|
|
|
|
|
|
|
self.ls1 = nn.Parameter(init_values * torch.ones(dim))
|
|
|
|
|
|
|
|
self.ls2 = nn.Parameter(init_values * torch.ones(dim))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
x = x + self.drop_path(self.ls1 * self.attn(self.norm1(x)))
|
|
|
|
|
|
|
|
x = x + self.drop_path(self.ls2 * self.mlp(self.norm2(x)))
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerStack(nn.Module):
|
|
|
|
class TransformerStack(nn.Module):
|
|
|
|
""" A stack-o-transformers
|
|
|
|
""" A stack-o-transformers
|
|
|
|
NOTE this could have been a simple nn.Sequential but needed to wrap in module to use Interface
|
|
|
|
NOTE this could have been a simple nn.Sequential but needed to wrap in module to use Interface
|
|
|
|
def for ModuleDict torchscript compat.
|
|
|
|
def for ModuleDict torchscript compat.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
def __init__(self, depth, dim, num_heads, mlp_ratio=4., **kwargs):
|
|
|
|
def __init__(self, depth, dim, num_heads, mlp_ratio=4., block=None, **kwargs):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
block = block or TransformerBlock
|
|
|
|
self.stack = nn.Sequential(*[
|
|
|
|
self.stack = nn.Sequential(*[
|
|
|
|
TransformerBlock(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, **kwargs) for _ in range(depth)])
|
|
|
|
block(dim=dim, num_heads=num_heads, mlp_ratio=mlp_ratio, **kwargs)
|
|
|
|
|
|
|
|
for _ in range(depth)
|
|
|
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
return self.stack(x)
|
|
|
|
return self.stack(x)
|
|
|
@ -235,7 +292,8 @@ class Perceiver(nn.Module):
|
|
|
|
latent_dim=1024, num_latents=512, num_latent_heads=8, latent_mlp_ratio=1.0,
|
|
|
|
latent_dim=1024, num_latents=512, num_latent_heads=8, latent_mlp_ratio=1.0,
|
|
|
|
cross_attn_dim=None, num_cross_heads=1, cross_mlp_ratio=1.0, share_weights=(1, 0),
|
|
|
|
cross_attn_dim=None, num_cross_heads=1, cross_mlp_ratio=1.0, share_weights=(1, 0),
|
|
|
|
pos_embed_type='fourier', pos_embed_dim=128, data_bands=64, data_ndim=2, data_max_freq=10,
|
|
|
|
pos_embed_type='fourier', pos_embed_dim=128, data_bands=64, data_ndim=2, data_max_freq=10,
|
|
|
|
data_spatial=False, qkv_bias=True, cross_attn_layer=None, attn_layer=None, norm_layer=None, act_layer=None,
|
|
|
|
data_spatial=False, qkv_bias=True, cross_block=None, transformer_block=None,
|
|
|
|
|
|
|
|
cross_attn_layer=None, attn_layer=None, norm_layer=None, act_layer=None,
|
|
|
|
drop_rate=0., drop_path_rate=0., weight_init=''):
|
|
|
|
drop_rate=0., drop_path_rate=0., weight_init=''):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
@ -263,6 +321,8 @@ class Perceiver(nn.Module):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_features = self.latent_dim = latent_dim
|
|
|
|
self.num_features = self.latent_dim = latent_dim
|
|
|
|
|
|
|
|
cross_block = cross_block or CrossBlock
|
|
|
|
|
|
|
|
transformer_block = transformer_block or TransformerBlock
|
|
|
|
cross_attn_layer = cross_attn_layer or CrossAttention
|
|
|
|
cross_attn_layer = cross_attn_layer or CrossAttention
|
|
|
|
attn_layer = attn_layer or Attention
|
|
|
|
attn_layer = attn_layer or Attention
|
|
|
|
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
|
|
|
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
|
|
@ -282,13 +342,13 @@ class Perceiver(nn.Module):
|
|
|
|
stage_args = dict(
|
|
|
|
stage_args = dict(
|
|
|
|
qkv_bias=qkv_bias, drop=drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer, act_layer=act_layer)
|
|
|
|
qkv_bias=qkv_bias, drop=drop_rate, drop_path=drop_path_rate, norm_layer=norm_layer, act_layer=act_layer)
|
|
|
|
if k.startswith('c'):
|
|
|
|
if k.startswith('c'):
|
|
|
|
self.blocks_cross[k] = CrossBlock(
|
|
|
|
self.blocks_cross[k] = cross_block(
|
|
|
|
latent_dim=latent_dim, data_dim=self.data_dim, attn_dim=cross_attn_dim, num_heads=num_cross_heads,
|
|
|
|
latent_dim=latent_dim, data_dim=self.data_dim, attn_dim=cross_attn_dim, num_heads=num_cross_heads,
|
|
|
|
mlp_ratio=cross_mlp_ratio, attn_layer=cross_attn_layer, **stage_args)
|
|
|
|
mlp_ratio=cross_mlp_ratio, attn_layer=cross_attn_layer, **stage_args)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
self.blocks_trans[k] = TransformerStack(
|
|
|
|
self.blocks_trans[k] = TransformerStack(
|
|
|
|
depth=transformer_depth, dim=latent_dim, num_heads=num_latent_heads,
|
|
|
|
depth=transformer_depth, dim=latent_dim, num_heads=num_latent_heads,
|
|
|
|
mlp_ratio=latent_mlp_ratio, attn_layer=attn_layer, **stage_args)
|
|
|
|
mlp_ratio=latent_mlp_ratio, attn_layer=attn_layer, block=transformer_block, **stage_args)
|
|
|
|
|
|
|
|
|
|
|
|
self.norm = norm_layer(latent_dim)
|
|
|
|
self.norm = norm_layer(latent_dim)
|
|
|
|
self.head = nn.Linear(latent_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
self.head = nn.Linear(latent_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
@ -408,6 +468,20 @@ def perceiver_m(pretrained=False, **kwargs):
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def perceiver_m_ls(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" Perceiver-Medium w/ LayerScale + Affine
|
|
|
|
|
|
|
|
Two cross attn (one per each initial transformer stack), all transformers shared. ~25M params.
|
|
|
|
|
|
|
|
LayerScale + Affine influenced by CaiT, LeViT, ResMLP from Facebook AI
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
|
|
|
cross_depths=(1,) * 2, latent_dim=768, num_latents=384, cross_attn_dim=160, data_bands=40,
|
|
|
|
|
|
|
|
transformer_block=TransformerBlockLayerScale, cross_block=CrossBlockLayerScale,
|
|
|
|
|
|
|
|
norm_layer=Affine, **kwargs)
|
|
|
|
|
|
|
|
model = _create_perceiver('perceiver_m_ls', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def perceiver_l(pretrained=False, **kwargs):
|
|
|
|
def perceiver_l(pretrained=False, **kwargs):
|
|
|
|
""" Perceiver-Large
|
|
|
|
""" Perceiver-Large
|
|
|
|