""" Attention Pool 2D Implementations of 2D spatial feature pooling using multi-head attention instead of average pool. Based on idea in CLIP by OpenAI, licensed Apache 2.0 https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py Hacked together by / Copyright 2021 Ross Wightman """ import math from typing import List, Union, Tuple import torch import torch.nn as nn from .helpers import to_2tuple from .weight_init import trunc_normal_ def rot(x): return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): return x * cos_emb + rot(x) * sin_emb def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): if isinstance(x, torch.Tensor): x = [x] return [t * cos_emb + rot(t) * sin_emb for t in x] class RotaryEmbedding(nn.Module): """ Rotary position embedding NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not been well tested, and will likely change. It will be moved to its own file. The following impl/resources were referenced for this impl: * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py * https://blog.eleuther.ai/rotary-embeddings/ """ def __init__(self, dim, max_freq=4): super().__init__() self.dim = dim self.register_buffer('bands', 2 ** torch.linspace(0., max_freq - 1, self.dim // 4), persistent=False) def get_embed(self, shape: torch.Size, device: torch.device = None, dtype: torch.dtype = None): """ NOTE: shape arg should include spatial dim only """ device = device or self.bands.device dtype = dtype or self.bands.dtype if not isinstance(shape, torch.Size): shape = torch.Size(shape) N = shape.numel() grid = torch.stack(torch.meshgrid( [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in shape]), dim=-1).unsqueeze(-1) emb = grid * math.pi * self.bands sin = emb.sin().reshape(N, -1).repeat_interleave(2, -1) cos = emb.cos().reshape(N, -1).repeat_interleave(2, -1) return sin, cos def forward(self, x): # assuming channel-first tensor where spatial dim are >= 2 sin_emb, cos_emb = self.get_embed(x.shape[2:]) return apply_rot_embed(x, sin_emb, cos_emb) class RotAttentionPool2d(nn.Module): """ Attention based 2D feature pooling w/ rotary (relative) pos embedding. This is a multi-head attention based replacement for (spatial) average pooling in NN architectures. Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed. https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py NOTE: While this impl does not require a fixed feature size, performance at differeing resolutions from train varies widely and falls off dramatically. I'm not sure if there is a way around this... 
class RotAttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: While this impl does not require a fixed feature size, performance at differing resolutions
    from train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
    """
    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pos_embed = RotaryEmbedding(self.head_dim)

        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        if qkv_bias:
            nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        sin_emb, cos_emb = self.pos_embed.get_embed(x.shape[2:])
        # flatten spatial dims to a token sequence, prepend the mean token as the pooling query
        x = x.reshape(B, -1, N).permute(0, 2, 1)
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]

        # rotary embed applies to the spatial tokens only, the leading (mean) token is left as-is
        qc, q = q[:, :, :1], q[:, :, 1:]
        q = apply_rot_embed(q, sin_emb, cos_emb)
        q = torch.cat([qc, q], dim=2)

        kc, k = k[:, :, :1], k[:, :, 1:]
        k = apply_rot_embed(k, sin_emb, cos_emb)
        k = torch.cat([kc, k], dim=2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        return x[:, 0]


class AttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ learned (absolute) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    It was based on impl in CLIP by OpenAI
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
    """
    def __init__(
            self,
            in_features: int,
            feat_size: Union[int, Tuple[int, int]],
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.feat_size = to_2tuple(feat_size)
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        spatial_dim = self.feat_size[0] * self.feat_size[1]
        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
        trunc_normal_(self.pos_embed, std=in_features ** -0.5)
        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        if qkv_bias:
            nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        assert self.feat_size[0] == H
        assert self.feat_size[1] == W
        # flatten spatial dims to a token sequence, prepend the mean token as the pooling query
        x = x.reshape(B, -1, N).permute(0, 2, 1)
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        return x[:, 0]
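

# A minimal smoke-test sketch (not part of the original module). The 2048-channel
# 7x7 feature map below is an illustrative assumption (the typical output of a
# ResNet-50 backbone), not a requirement of the classes above. Because of the
# relative imports, this file must be run as a module (python -m ...) for the
# block to execute.
if __name__ == '__main__':
    feats = torch.randn(2, 2048, 7, 7)  # (batch, channels, H, W)

    # rotary variant: no fixed feature size required at construction
    rot_pool = RotAttentionPool2d(in_features=2048, out_features=1024, num_heads=8)
    print(rot_pool(feats).shape)  # torch.Size([2, 1024])

    # learned pos embed variant: feat_size must match the input H, W
    abs_pool = AttentionPool2d(in_features=2048, feat_size=7, out_features=1024, num_heads=8)
    print(abs_pool(feats).shape)  # torch.Size([2, 1024])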