Add initial AttentionPool2d that's being trialed. Fix comment and still trying to improve reliability of sgd test.
parent
76881d207b
commit
5f12de4875
@ -0,0 +1,182 @@
|
||||
""" Attention Pool 2D
|
||||
|
||||
Implementations of 2D spatial feature pooling using multi-head attention instead of average pool.
|
||||
|
||||
Based on idea in CLIP by OpenAI, licensed Apache 2.0
|
||||
https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
|
||||
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import math
|
||||
from typing import List, Union, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .helpers import to_2tuple
|
||||
from .weight_init import trunc_normal_
|
||||
|
||||
|
||||
def rot(x):
|
||||
return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
|
||||
|
||||
|
||||
def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
|
||||
return x * cos_emb + rot(x) * sin_emb
|
||||
|
||||
|
||||
def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = [x]
|
||||
return [t * cos_emb + rot(t) * sin_emb for t in x]
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
""" Rotary position embedding
|
||||
|
||||
NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
|
||||
been well tested, and will likely change. It will be moved to its own file.
|
||||
|
||||
The following impl/resources were referenced for this impl:
|
||||
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
|
||||
* https://blog.eleuther.ai/rotary-embeddings/
|
||||
"""
|
||||
def __init__(self, dim, max_freq=4):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.register_buffer('bands', 2 ** torch.linspace(0., max_freq - 1, self.dim // 4), persistent=False)
|
||||
|
||||
def get_embed(self, shape: torch.Size, device: torch.device = None, dtype: torch.dtype = None):
|
||||
"""
|
||||
NOTE: shape arg should include spatial dim only
|
||||
"""
|
||||
device = device or self.bands.device
|
||||
dtype = dtype or self.bands.dtype
|
||||
if not isinstance(shape, torch.Size):
|
||||
shape = torch.Size(shape)
|
||||
N = shape.numel()
|
||||
grid = torch.stack(torch.meshgrid(
|
||||
[torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in shape]), dim=-1).unsqueeze(-1)
|
||||
emb = grid * math.pi * self.bands
|
||||
sin = emb.sin().reshape(N, -1).repeat_interleave(2, -1)
|
||||
cos = emb.cos().reshape(N, -1).repeat_interleave(2, -1)
|
||||
return sin, cos
|
||||
|
||||
def forward(self, x):
|
||||
# assuming channel-first tensor where spatial dim are >= 2
|
||||
sin_emb, cos_emb = self.get_embed(x.shape[2:])
|
||||
return apply_rot_embed(x, sin_emb, cos_emb)
|
||||
|
||||
|
||||
class RotAttentionPool2d(nn.Module):
|
||||
""" Attention based 2D feature pooling w/ rotary (relative) pos embedding.
|
||||
This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
|
||||
|
||||
Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
|
||||
https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
|
||||
|
||||
NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from
|
||||
train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int = None,
|
||||
embed_dim: int = None,
|
||||
num_heads: int = 4,
|
||||
qkv_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
embed_dim = embed_dim or in_features
|
||||
out_features = out_features or in_features
|
||||
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(embed_dim, out_features)
|
||||
self.num_heads = num_heads
|
||||
assert embed_dim % num_heads == 0
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
self.pos_embed = RotaryEmbedding(self.head_dim)
|
||||
|
||||
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
|
||||
nn.init.zeros_(self.qkv.bias)
|
||||
|
||||
def forward(self, x):
|
||||
B, _, H, W = x.shape
|
||||
N = H * W
|
||||
sin_emb, cos_emb = self.pos_embed.get_embed(x.shape[2:])
|
||||
x = x.reshape(B, -1, N).permute(0, 2, 1)
|
||||
|
||||
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
|
||||
|
||||
x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = x[0], x[1], x[2]
|
||||
|
||||
qc, q = q[:, :, :1], q[:, :, 1:]
|
||||
q = apply_rot_embed(q, sin_emb, cos_emb)
|
||||
q = torch.cat([qc, q], dim=2)
|
||||
|
||||
kc, k = k[:, :, :1], k[:, :, 1:]
|
||||
k = apply_rot_embed(k, sin_emb, cos_emb)
|
||||
k = torch.cat([kc, k], dim=2)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
|
||||
x = self.proj(x)
|
||||
return x[:, 0]
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
""" Attention based 2D feature pooling w/ learned (absolute) pos embedding.
|
||||
This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
|
||||
|
||||
It was based on impl in CLIP by OpenAI
|
||||
https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
|
||||
|
||||
NOTE: This requires feature size upon construction and well prevent adaptive sizing of the network.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
feat_size: Union[int, Tuple[int, int]],
|
||||
out_features: int = None,
|
||||
embed_dim: int = None,
|
||||
num_heads: int = 4,
|
||||
qkv_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
embed_dim = embed_dim or in_features
|
||||
out_features = out_features or in_features
|
||||
assert embed_dim % num_heads == 0
|
||||
self.feat_size = to_2tuple(feat_size)
|
||||
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(embed_dim, out_features)
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
spatial_dim = self.feat_size[0] * self.feat_size[1]
|
||||
self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
|
||||
trunc_normal_(self.pos_embed, std=in_features ** -0.5)
|
||||
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
|
||||
nn.init.zeros_(self.qkv.bias)
|
||||
|
||||
def forward(self, x):
|
||||
B, _, H, W = x.shape
|
||||
N = H * W
|
||||
assert self.feat_size[0] == H
|
||||
assert self.feat_size[1] == W
|
||||
x = x.reshape(B, -1, N).permute(0, 2, 1)
|
||||
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
|
||||
x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
|
||||
|
||||
x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = x[0], x[1], x[2]
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
|
||||
x = self.proj(x)
|
||||
return x[:, 0]
|
Loading…
Reference in new issue