You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/layers/lambda_layer.py

134 lines
5.8 KiB

""" Lambda Layer
Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
- https://arxiv.org/abs/2102.08602
@misc{2102.08602,
Author = {Irwan Bello},
Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention},
Year = {2021},
}
Status:
This impl is a WIP. Code snippets in the paper were used as reference but
good chance some details are missing/wrong.
I've only implemented local lambda conv based pos embeddings.
For a PyTorch impl that includes other embedding options, check out
https://github.com/lucidrains/lambda-networks
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
from torch import nn
import torch.nn.functional as F
from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_
def rel_pos_indices(size):
    """Build pairwise relative-position lookup indices for an H x W grid.

    Args:
        size: grid size, passed through ``to_2tuple`` and then read as ``(H, W)``
            (presumably a bare int is expanded to a pair -- confirm against the helper).

    Returns:
        Integer tensor of shape (2, H * W, H * W). Channel 0 holds the row offset and
        channel 1 the column offset between every pair of flattened grid positions,
        shifted by ``H - 1`` / ``W - 1`` so values lie in ``[0, 2 * H - 2]`` and
        ``[0, 2 * W - 2]`` and can index a ``(2 * H - 1, 2 * W - 1, ...)`` table directly.
    """
    size = to_2tuple(size)
    coords = (torch.arange(size[0]), torch.arange(size[1]))
    try:
        # Ask for 'ij' (matrix) indexing explicitly. It is what the un-parameterised call
        # returned, and being explicit avoids the UserWarning raised by PyTorch >= 1.10.
        grid = torch.meshgrid(*coords, indexing='ij')
    except TypeError:
        # PyTorch < 1.10 has no `indexing` argument; its only behaviour was 'ij'.
        grid = torch.meshgrid(*coords)
    pos = torch.stack(grid).flatten(1)  # 2, H * W
    rel_pos = pos[:, None, :] - pos[:, :, None]  # 2, H * W, H * W
    # Shift offsets from [-(S - 1), S - 1] to [0, 2 * S - 2] so they are valid indices.
    rel_pos[0] += size[0] - 1
    rel_pos[1] += size[1] - 1
    return rel_pos  # 2, H * W, H * W
class LambdaLayer(nn.Module):
    """Lambda Layer

    Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
        - https://arxiv.org/abs/2102.08602

    NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.

    The internal dimensions of the lambda module are controlled via the interaction of several arguments.
      * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
      * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
      * the query (q) and key (k) dimension are determined by
        * dim_head = (make_divisible(dim_out * qk_ratio, 8) // num_heads) if dim_head is None (or 0)
        * q = num_heads * dim_head, k = dim_head
      * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head not set

    Args:
        dim (int): input dimension to the module
        dim_out (int): output dimension of the module, same as dim if not set. Must be divisible by num_heads.
        feat_size (Tuple[int, int]): size of input feature_map (H, W). Required for the relative pos
            variant (r=None), where the input to forward() must have exactly this spatial size.
        stride (int): output stride of the module, avg pool used if stride == 2
        num_heads (int): parallel attention heads.
        dim_head (int): dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads
            if not set (default: 16)
        r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not.
            Should be odd so that the r // 2 padding preserves the spatial size. (default: 9)
        qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
        qkv_bias (bool): add bias to q, k, and v projections
    """

    def __init__(
            self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9,
            qk_ratio=1.0, qkv_bias=False):
        super().__init__()
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0, ' should be divided by num_heads'
        self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        self.num_heads = num_heads
        self.dim_v = dim_out // num_heads

        # A single 1x1 conv produces q (num_heads * dim_qk), k (dim_qk) and v (dim_v) together.
        self.qkv = nn.Conv2d(
            dim,
            num_heads * self.dim_qk + self.dim_qk + self.dim_v,
            kernel_size=1, bias=qkv_bias)
        self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)
        self.norm_v = nn.BatchNorm2d(self.dim_v)

        if r is not None:
            # local lambda convolution for pos
            self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0))
            self.pos_emb = None
            self.rel_pos_indices = None
        else:
            # relative pos embedding
            assert feat_size is not None
            feat_size = to_2tuple(feat_size)
            rel_size = [2 * s - 1 for s in feat_size]
            self.conv_lambda = None
            self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk))
            self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)

        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise the qkv projection and whichever positional parameters exist."""
        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)  # fan-in
        if self.conv_lambda is not None:
            trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
        if self.pos_emb is not None:
            trunc_normal_(self.pos_emb, std=.02)

    def forward(self, x):
        """Apply the lambda layer.

        Args:
            x: input feature map of shape (B, dim, H, W).

        Returns:
            Tensor of shape (B, num_heads * dim_v, H, W), spatially halved by the
            average pool when the layer was built with stride == 2.
        """
        B, _, H, W = x.shape
        M = H * W
        qkv = self.qkv(x)
        q, k, v = torch.split(qkv, [
            self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
        q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2)  # B, num_heads, M, K
        v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2)  # B, M, V
        k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1)  # B, K, M (normalised over positions)

        content_lam = k @ v  # B, K, V
        content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V

        if self.pos_emb is None:
            position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, K, H, W, V
            position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
        else:
            # FIXME relative pos embedding path not fully verified
            pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
            position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1)  # B, 1, M, K, V
        position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2)  # B, num_heads, M, V

        # The result carries num_heads * dim_v (== dim_out) channels, which differs from the
        # input channel count whenever dim_out != dim, so reshape with the output count.
        out = (content_out + position_out).transpose(-1, -2).reshape(
            B, self.num_heads * self.dim_v, H, W)  # B, C_out (num_heads * V), H, W
        out = self.pool(out)
        return out