You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/models/layers/halo_attn.py

184 lines
7.7 KiB

""" Halo Self Attention
Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
- https://arxiv.org/abs/2103.12731
@misc{2103.12731,
Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and
Jonathon Shlens},
Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones},
Year = {2021},
}
Status:
This impl is a WIP, there is no official ref impl and some details in paper weren't clear to me.
Trying to match the 'H1' variant in the paper, my parameter counts are 2M less and the model
is extremely slow. Something isn't right. However, the models do appear to train, and experimental
variants with attn in C4 and/or C5 stages run at a tolerable speed.
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Tuple, List
import torch
from torch import nn
import torch.nn.functional as F
from .weight_init import trunc_normal_
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
    """ Relative position logits along a single spatial axis.

    Logits are first computed against every relative offset, then shifted from relative to absolute
    indexing with a pad / flatten / reshape sequence, and finally tiled over the other window axis.

    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925

    Args:
        q: (batch, height, width, dim)
        rel_k: (2 * window - 1, dim)
        permute_mask: permute output dim according to this

    Returns:
        A (batch, height, window, width, window) tensor with its dims reordered by `permute_mask`.
    """
    batch, height, width, _ = q.shape
    num_rel = rel_k.shape[0]
    window = (num_rel + 1) // 2

    # one row of relative-offset logits per query position along the width axis
    logits = (q @ rel_k.transpose(-1, -2)).reshape(-1, width, num_rel)

    # relative -> absolute indexing: add one pad column per row, flatten the rows together and pad
    # the tail so the buffer re-rows as (width + 1, num_rel); each row ends up shifted by its index
    shifted = F.pad(logits, [0, 1]).flatten(1)
    shifted = F.pad(shifted, [0, num_rel - width])
    shifted = shifted.reshape(-1, width + 1, num_rel)

    # drop the extra row and the leading columns that only hold shifted-in / padded values
    logits = shifted[:, :width, window - 1:]

    # add a singleton axis for the other window dim and tile (as a view) across it
    tiled = logits.reshape(batch, height, 1, width, window).expand(-1, -1, window, -1, -1)
    return tiled.permute(permute_mask)
class PosEmbedRel(nn.Module):
    """ Relative Position Embedding

    Holds one learned table of relative offsets per spatial axis and turns blocked queries into
    additive attention logits by summing the per-axis relative logits.

    As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
    Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
    """

    def __init__(self, block_size, win_size, dim_head, scale):
        """
        Args:
            block_size (int): block size
            win_size (int): neighbourhood window size
            dim_head (int): attention head dim
            scale (float): scale factor (for init)
        """
        super().__init__()
        self.block_size = block_size
        self.dim_head = dim_head
        self.scale = scale
        # one embedding row per relative offset in [-(win_size - 1), win_size - 1]
        num_rel = 2 * win_size - 1
        self.height_rel = nn.Parameter(scale * torch.randn(num_rel, dim_head))
        self.width_rel = nn.Parameter(scale * torch.randn(num_rel, dim_head))

    def forward(self, q):
        """ Compute relative position logits for blocked queries.

        Args:
            q: 4-dim tensor whose last dim is `dim_head` and whose third dim unfolds to
                `block_size * block_size` query positions

        Returns:
            Tensor with the same first three dims as `q`; the last dim flattens the key window.
        """
        lead, blocks, positions, _ = q.shape

        # unfold the flattened query positions back into a 2D (block_size, block_size) grid
        q_grid = q.reshape(-1, self.block_size, self.block_size, self.dim_head)

        # width axis logits, tiled over the window height
        logits_w = rel_logits_1d(q_grid, self.width_rel, permute_mask=(0, 1, 3, 2, 4))

        # height axis logits: swap the spatial axes so height is processed as the trailing axis
        logits_h = rel_logits_1d(q_grid.transpose(1, 2), self.height_rel, permute_mask=(0, 3, 1, 4, 2))

        return (logits_h + logits_w).reshape(lead, blocks, positions, -1)
class HaloAttn(nn.Module):
    """ Halo Attention

    Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
        - https://arxiv.org/abs/2103.12731

    Queries are grouped into non-overlapping `block_size` x `block_size` blocks. Each block attends
    to a (block_size + 2 * halo_size) square window of keys / values taken from the zero padded
    kv projection, with learned relative position logits added to the attention logits.

    Args:
        dim (int): input channels
        dim_out (int): output channels, defaults to `dim`; must be divisible by `num_heads`
        stride (int): stride of the query projection (output is downsampled by this factor);
            must evenly divide `block_size`
        num_heads (int): number of attention heads
        dim_head (int): query / key dim per head, defaults to `dim // num_heads`
        block_size (int): query block size, input height and width must be multiples of it
        halo_size (int): extra key / value rows and columns gathered on each side of a block
        qkv_bias (bool): add a bias term to the q and kv projections
    """

    def __init__(
            self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0
        # after the strided q projection each block is (block_size // stride) wide; a remainder here
        # can never produce matching shapes and would only surface as a reshape error in forward()
        assert block_size % stride == 0, 'block_size must be divisible by stride'
        self.stride = stride
        self.num_heads = num_heads
        self.dim_head = dim_head or dim // num_heads
        self.dim_qk = num_heads * self.dim_head
        self.dim_v = dim_out
        self.block_size = block_size
        self.halo_size = halo_size
        self.win_size = block_size + halo_size * 2  # neighbourhood window size
        self.scale = self.dim_head ** -0.5
        # stride_tricks hard-coded for now, works well on CPU / GPU, neither unfold or as_strided works on TPU (XLA)
        self.stride_tricks = True

        # FIXME not clear if this stride behaviour is what the paper intended
        # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
        # data in unfolded block form. I haven't wrapped my head around how that'd look.
        self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias)
        self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias)

        self.pos_embed = PosEmbedRel(
            block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)

        # apply the init defined below; without this call reset_parameters() was never used
        self.reset_parameters()

    def reset_parameters(self):
        """ (Re)initialise projection weights and relative position tables with truncated normals. """
        std = self.q.weight.shape[1] ** -0.5  # fan-in (input channels of the 1x1 projections)
        trunc_normal_(self.q.weight, std=std)
        trunc_normal_(self.kv.weight, std=std)
        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
        trunc_normal_(self.pos_embed.width_rel, std=self.scale)

    def forward(self, x):
        """ Apply halo attention.

        Args:
            x: (B, dim, H, W) feature map, H and W must be divisible by `block_size`

        Returns:
            (B, dim_out, H // stride, W // stride) feature map
        """
        B, _, H, W = x.shape
        assert H % self.block_size == 0 and W % self.block_size == 0
        num_h_blocks = H // self.block_size
        num_w_blocks = W // self.block_size
        num_blocks = num_h_blocks * num_w_blocks
        bs_stride = self.block_size // self.stride
        dim_head_v = self.dim_v // self.num_heads

        q = self.q(x)
        # query blocks do not overlap, so a reshape + permute replaces
        # F.unfold(q, kernel_size=bs_stride, stride=bs_stride)
        q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
        q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
        # B * num_heads, num_blocks, bs_stride ** 2, dim_head

        kv = self.kv(x)
        # generate overlapping windows using either stride tricks (as_strided) or unfold
        if self.stride_tricks:
            # this is much faster
            kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
            # view of the padded map as (B, C, win_h, win_w, block_row, block_col): stepping one block
            # moves block_size rows / columns, stepping inside a window moves a single row / column
            kv = kv.as_strided((
                B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
                stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
        else:
            kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)

        kv = kv.reshape(B * self.num_heads, self.dim_head + dim_head_v, -1, num_blocks).transpose(1, 3)
        k, v = torch.split(kv, [self.dim_head, dim_head_v], dim=-1)
        # B * num_heads, num_blocks, win_size ** 2, dim_head or dim_v // num_heads

        attn_logits = (q @ k.transpose(-1, -2)) * self.scale  # FIXME should usual attn scale be applied?
        attn_logits = attn_logits + self.pos_embed(q)
        # B * num_heads, num_blocks, bs_stride ** 2, win_size ** 2

        attn_out = attn_logits.softmax(dim=-1)
        attn_out = (attn_out @ v).transpose(1, 3)  # B * num_heads, dim_v // num_heads, bs_stride ** 2, num_blocks

        # fold the blocks back into a feature map; reshape + permute is slightly faster than the equivalent
        # F.fold(attn_out.reshape(B, -1, num_blocks), (H // self.stride, W // self.stride),
        #        kernel_size=bs_stride, stride=bs_stride)
        attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
        attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
        # B, dim_out, H // stride, W // stride
        return attn_out