""" Halo Self Attention

Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
    - https://arxiv.org/abs/2103.12731

@misc{2103.12731,
    Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and
        Jonathon Shlens},
    Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones},
    Year = {2021},
}

Status:
This impl is a WIP, there is no official ref impl and some details in the paper weren't clear to me.

Trying to match the 'H1' variant in the paper, my parameter counts are 2M less and the model
is extremely slow. Something isn't right. However, the models do appear to train and experimental
variants with attn in C4 and/or C5 stages run at a tolerable speed.

Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Tuple, List

import torch
from torch import nn
import torch.nn.functional as F

from .weight_init import trunc_normal_


def rel_logits_1d(q, rel_k, permute_mask: List[int]):
    """ Compute relative logits along one dimension

    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925

    Args:
        q: (batch, height, width, dim)
        rel_k: (2 * window - 1, dim)
        permute_mask: permute output dim according to this
    """
    B, H, W, dim = q.shape
    rel_size = rel_k.shape[0]
    win_size = (rel_size + 1) // 2

    x = (q @ rel_k.transpose(-1, -2))
    x = x.reshape(-1, W, rel_size)

    # pad to shift from relative to absolute indexing
    x_pad = F.pad(x, [0, 1]).flatten(1)
    x_pad = F.pad(x_pad, [0, rel_size - W])

    # reshape and slice out the padded elements
    x_pad = x_pad.reshape(-1, W + 1, rel_size)
    x = x_pad[:, :W, win_size - 1:]

    # reshape and tile
    x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1)
    return x.permute(permute_mask)
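

# A minimal shape-check sketch (my addition, not part of the original module). With a
# 4x4 query block and halo 1 (so win_size = 6), rel_k has 2 * 6 - 1 = 11 relative
# positions and the width logits broadcast to (B, 4, 4, 6, 6) after the permute.
# The helper name is hypothetical and only illustrates expected shapes.
def _rel_logits_1d_shape_example():
    q = torch.randn(2, 4, 4, 16)  # (batch, block_h, block_w, dim_head)
    rel_k = torch.randn(11, 16)  # 2 * win_size - 1 = 11 relative positions
    out = rel_logits_1d(q, rel_k, permute_mask=(0, 1, 3, 2, 4))
    assert out.shape == (2, 4, 4, 6, 6)  # (batch, query_h, query_w, key_h, key_w)
    return out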


class PosEmbedRel(nn.Module):
    """ Relative Position Embedding
    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
    """
    def __init__(self, block_size, win_size, dim_head, scale):
        """
        Args:
            block_size (int): block size
            win_size (int): neighbourhood window size
            dim_head (int): attention head dim
            scale (float): scale factor (for init)
        """
        super().__init__()
        self.block_size = block_size
        self.dim_head = dim_head
        self.scale = scale
        self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)
        self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)

    def forward(self, q):
        B, BB, HW, _ = q.shape

        # relative logits in width dimension.
        q = q.reshape(-1, self.block_size, self.block_size, self.dim_head)
        rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))

        # relative logits in height dimension.
        q = q.transpose(1, 2)
        rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))

        rel_logits = rel_logits_h + rel_logits_w
        rel_logits = rel_logits.reshape(B, BB, HW, -1)
        return rel_logits
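

# A minimal sketch (my addition, hypothetical helper) of how HaloAttn feeds PosEmbedRel:
# q comes in as (B * num_heads, num_blocks, block_size ** 2, dim_head) and the returned
# logits are (B * num_heads, num_blocks, block_size ** 2, win_size ** 2).
def _pos_embed_rel_shape_example():
    block_size, halo_size, dim_head = 4, 1, 16
    win_size = block_size + 2 * halo_size
    pos = PosEmbedRel(block_size=block_size, win_size=win_size, dim_head=dim_head, scale=dim_head ** -0.5)
    q = torch.randn(2 * 8, 9, block_size * block_size, dim_head)  # batch 2, 8 heads, 9 blocks
    rel = pos(q)
    assert rel.shape == (2 * 8, 9, block_size * block_size, win_size * win_size)
    return rel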


class HaloAttn(nn.Module):
    """ Halo Attention

    Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
        - https://arxiv.org/abs/2103.12731
    """
    def __init__(
            self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0
        self.stride = stride
        self.num_heads = num_heads
        self.dim_head = dim_head or dim // num_heads
        self.dim_qk = num_heads * self.dim_head
        self.dim_v = dim_out
        self.block_size = block_size
        self.halo_size = halo_size
        self.win_size = block_size + halo_size * 2  # neighbourhood window size
        self.scale = self.dim_head ** -0.5
        # stride_tricks hard-coded for now, works well on CPU / GPU, neither unfold nor as_strided works on TPU (XLA)
        self.stride_tricks = True

        # FIXME not clear if this stride behaviour is what the paper intended
        # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
        # data in unfolded block form. I haven't wrapped my head around how that'd look.
        self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias)
        self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias)

        self.pos_embed = PosEmbedRel(
            block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)

    def reset_parameters(self):
        std = self.q.weight.shape[1] ** -0.5  # fan-in
        trunc_normal_(self.q.weight, std=std)
        trunc_normal_(self.kv.weight, std=std)
        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
        trunc_normal_(self.pos_embed.width_rel, std=self.scale)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H % self.block_size == 0 and W % self.block_size == 0
        num_h_blocks = H // self.block_size
        num_w_blocks = W // self.block_size
        num_blocks = num_h_blocks * num_w_blocks
        bs_stride = self.block_size // self.stride

        q = self.q(x)
        # q = F.unfold(q, kernel_size=bs_stride, stride=bs_stride)  # don't need to use unfold here since no overlap
        q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
        # B * num_heads, dim_head, block_size // stride, block_size // stride, num_h_blocks, num_w_blocks
        q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
        # B * num_heads, num_blocks, block_size ** 2, dim_head

        kv = self.kv(x)

        # generate overlapping windows using either stride tricks (as_strided) or unfold
        if self.stride_tricks:
            # this is much faster
            kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
            kv = kv.as_strided((
                B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
                stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
        else:
            kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)

        kv = kv.reshape(
            B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
        k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
        # B * num_heads, num_blocks, win_size ** 2, dim_head or dim_v // num_heads

        attn_logits = (q @ k.transpose(-1, -2)) * self.scale  # FIXME should usual attn scale be applied?
        attn_logits = attn_logits + self.pos_embed(q)  # B * num_heads, num_blocks, block_size ** 2, win_size ** 2

        attn_out = attn_logits.softmax(dim=-1)
        attn_out = (attn_out @ v).transpose(1, 3)  # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks

        # F.fold can be replaced by reshape + permute, slightly faster
        # attn_out = F.fold(
        #     attn_out.reshape(B, -1, num_blocks),
        #     (H // self.stride, W // self.stride), kernel_size=bs_stride, stride=bs_stride)
        attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
        attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
        # B, dim_out, H // stride, W // stride
        return attn_out
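

# A hedged, standalone smoke test (my addition, not part of the original module). It checks
# that the as_strided and F.unfold paths used in HaloAttn.forward gather the same overlapping
# windows, and that a HaloAttn layer produces the expected output shape. The helper name and
# the sizes below (channels, heads, block/halo) are arbitrary illustrative choices. Run via
# `python -m` so the relative import above resolves.
def _halo_windows_match(B=2, C=8, H=16, W=16, block_size=8, halo_size=3):
    win_size = block_size + halo_size * 2
    num_h_blocks, num_w_blocks = H // block_size, W // block_size
    x = torch.randn(B, C, H, W)
    # stride-tricks gather, mirroring the as_strided branch in HaloAttn.forward
    x_pad = F.pad(x, [halo_size] * 4).contiguous()
    strided = x_pad.as_strided(
        (B, C, win_size, win_size, num_h_blocks, num_w_blocks),
        stride=(x_pad.stride(0), x_pad.stride(1), x_pad.shape[-1], 1,
                block_size * x_pad.shape[-1], block_size))
    # unfold gather, mirroring the F.unfold branch
    unfolded = F.unfold(x, kernel_size=win_size, stride=block_size, padding=halo_size)
    unfolded = unfolded.reshape(B, C, win_size, win_size, num_h_blocks, num_w_blocks)
    return torch.equal(strided, unfolded)


if __name__ == '__main__':
    assert _halo_windows_match()
    attn = HaloAttn(dim=64, dim_out=64, num_heads=8, block_size=8, halo_size=3)
    attn.reset_parameters()
    out = attn(torch.randn(2, 64, 32, 32))
    assert out.shape == (2, 64, 32, 32)
    print('halo attn smoke test passed')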