@ -24,7 +24,7 @@ import torch
from torch import nn
import torch . nn . functional as F
from . helpers import to_2tuple
from . helpers import to_2tuple , make_divisible
from . weight_init import trunc_normal_
@ -44,28 +44,46 @@ class LambdaLayer(nn.Module):
- https : / / arxiv . org / abs / 2102.08602
NOTE : intra - depth parameter ' u ' is fixed at 1. It did not appear worth the complexity to add .
The internal dimensions of the lambda module are controlled via the interaction of several arguments .
* the output dimension of the module is specified by dim_out , which falls back to input dim if not set
* the value ( v ) dimension is set to dim_out / / num_heads , the v projection determines the output dim
* the query ( q ) and key ( k ) dimension are determined by
* dim_head = ( dim_out * attn_ratio / / num_heads ) if dim_head is None
* q = num_heads * dim_head , k = dim_head
* as seen above , attn_ratio determines the ratio of q and k relative to the output if dim_head not set
Args :
dim ( int ) : input dimension to the module
dim_out ( int ) : output dimension of the module , same as dim if not set
feat_size ( Tuple [ int , int ] ) : size of input feature_map for relative pos variant H , W
stride ( int ) : output stride of the module , avg pool used if stride == 2
num_heads ( int ) : parallel attention heads .
dim_head ( int ) : dimension of query and key heads , calculated from dim_out * attn_ratio / / num_heads if not set
r ( int ) : local lambda convolution radius . Use lambda conv if set , else relative pos if not . ( default : 9 )
qk_ratio ( float ) : ratio of q and k dimensions to output dimension when dim_head not set . ( default : 1.0 )
qkv_bias ( bool ) : add bias to q , k , and v projections
"""
def __init__ (
self ,
dim , dim_out = None , feat_size = None , stride = 1 , num_heads = 4 , dim_head = 16 , r = 7 , qkv_bias = False ) :
self , dim , dim_out = None , feat_size = None , stride = 1 , num_heads = 4 , dim_head = 16 , r = 9 ,
qk_ratio= 1.0 , qkv_bias = False ) :
super ( ) . __init__ ( )
self . dim = dim
self . dim_out = dim_out or dim
self . dim_k = dim_head # query depth 'k'
dim_out = dim_out or dim
assert dim_out % num_heads == 0 , ' should be divided by num_heads '
self . dim_ qk = dim_head or make_divisible ( dim_out * qk_ratio , divisor = 8 ) / / num_heads
self . num_heads = num_heads
assert self . dim_out % num_heads == 0 , ' should be divided by num_heads '
self . dim_v = self . dim_out / / num_heads # value depth 'v'
self . dim_v = dim_out / / num_heads
self . qkv = nn . Conv2d (
dim ,
num_heads * dim_head + dim_head + self . dim_v ,
num_heads * self . dim_qk + self . dim_qk + self . dim_v ,
kernel_size = 1 , bias = qkv_bias )
self . norm_q = nn . BatchNorm2d ( num_heads * dim_head )
self . norm_q = nn . BatchNorm2d ( num_heads * self . dim_qk )
self . norm_v = nn . BatchNorm2d ( self . dim_v )
if r is not None :
# local lambda convolution for pos
self . conv_lambda = nn . Conv3d ( 1 , dim_head , ( r , r , 1 ) , padding = ( r / / 2 , r / / 2 , 0 ) )
self . conv_lambda = nn . Conv3d ( 1 , self . dim_qk , ( r , r , 1 ) , padding = ( r / / 2 , r / / 2 , 0 ) )
self . pos_emb = None
self . rel_pos_indices = None
else :
@ -74,7 +92,7 @@ class LambdaLayer(nn.Module):
feat_size = to_2tuple ( feat_size )
rel_size = [ 2 * s - 1 for s in feat_size ]
self . conv_lambda = None
self . pos_emb = nn . Parameter ( torch . zeros ( rel_size [ 0 ] , rel_size [ 1 ] , self . dim_ k) )
self . pos_emb = nn . Parameter ( torch . zeros ( rel_size [ 0 ] , rel_size [ 1 ] , self . dim_ q k) )
self . register_buffer ( ' rel_pos_indices ' , rel_pos_indices ( feat_size ) , persistent = False )
self . pool = nn . AvgPool2d ( 2 , 2 ) if stride == 2 else nn . Identity ( )
@ -82,9 +100,9 @@ class LambdaLayer(nn.Module):
self . reset_parameters ( )
def reset_parameters ( self ) :
trunc_normal_ ( self . qkv . weight , std = self . dim * * - 0.5 )
trunc_normal_ ( self . qkv . weight , std = self . qkv. weight . shape [ 1 ] * * - 0.5 ) # fan-in
if self . conv_lambda is not None :
trunc_normal_ ( self . conv_lambda . weight , std = self . dim_ k * * - 0.5 )
trunc_normal_ ( self . conv_lambda . weight , std = self . dim_ q k * * - 0.5 )
if self . pos_emb is not None :
trunc_normal_ ( self . pos_emb , std = .02 )
@ -93,17 +111,17 @@ class LambdaLayer(nn.Module):
M = H * W
qkv = self . qkv ( x )
q , k , v = torch . split ( qkv , [
self . num_heads * self . dim_ k, self . dim_ k, self . dim_v ] , dim = 1 )
q = self . norm_q ( q ) . reshape ( B , self . num_heads , self . dim_ k, M ) . transpose ( - 1 , - 2 ) # B, num_heads, M, K
self . num_heads * self . dim_ q k, self . dim_ q k, self . dim_v ] , dim = 1 )
q = self . norm_q ( q ) . reshape ( B , self . num_heads , self . dim_ q k, M ) . transpose ( - 1 , - 2 ) # B, num_heads, M, K
v = self . norm_v ( v ) . reshape ( B , self . dim_v , M ) . transpose ( - 1 , - 2 ) # B, M, V
k = F . softmax ( k . reshape ( B , self . dim_ k, M ) , dim = - 1 ) # B, K, M
k = F . softmax ( k . reshape ( B , self . dim_ q k, M ) , dim = - 1 ) # B, K, M
content_lam = k @ v # B, K, V
content_out = q @ content_lam . unsqueeze ( 1 ) # B, num_heads, M, V
if self . pos_emb is None :
position_lam = self . conv_lambda ( v . reshape ( B , 1 , H , W , self . dim_v ) ) # B, H, W, V, K
position_lam = position_lam . reshape ( B , 1 , self . dim_ k, H * W , self . dim_v ) . transpose ( 2 , 3 ) # B, 1, M, K, V
position_lam = position_lam . reshape ( B , 1 , self . dim_ q k, H * W , self . dim_v ) . transpose ( 2 , 3 ) # B, 1, M, K, V
else :
# FIXME relative pos embedding path not fully verified
pos_emb = self . pos_emb [ self . rel_pos_indices [ 0 ] , self . rel_pos_indices [ 1 ] ] . expand ( B , - 1 , - 1 , - 1 )