Add option to include relative pos embedding in the attention scaling as per references. See discussion #912

pull/914/head
Ross Wightman 3 years ago
parent 2c33ca6d8c
commit 02daf2ab94

@ -61,9 +61,8 @@ class PosEmbedRel(nn.Module):
super().__init__()
self.height, self.width = to_2tuple(feat_size)
self.dim_head = dim_head
self.scale = scale
self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale)
self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale)
self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale)
self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale)
def forward(self, q):
B, HW, _ = q.shape
@ -101,10 +100,11 @@ class BottleneckAttn(nn.Module):
dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
qkv_bias (bool): add bias to q, k, and v projections
scale_pos_embed (bool): scale the position embedding as well as Q @ K
"""
def __init__(
self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None,
qk_ratio=1.0, qkv_bias=False):
qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False):
super().__init__()
assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
dim_out = dim_out or dim
@ -115,6 +115,7 @@ class BottleneckAttn(nn.Module):
self.dim_out_qk = num_heads * self.dim_head_qk
self.dim_out_v = num_heads * self.dim_head_v
self.scale = self.dim_head_qk ** -0.5
self.scale_pos_embed = scale_pos_embed
self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
@ -144,8 +145,10 @@ class BottleneckAttn(nn.Module):
k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
attn = (q @ k) * self.scale
attn = attn + self.pos_embed(q) # B * num_heads, H * W, H * W
if self.scale_pos_embed:
attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
else:
attn = (q @ k) * self.scale + self.pos_embed(q)
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W

@ -74,9 +74,8 @@ class PosEmbedRel(nn.Module):
super().__init__()
self.block_size = block_size
self.dim_head = dim_head
self.scale = scale
self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)
self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)
self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
def forward(self, q):
B, BB, HW, _ = q.shape
@ -120,11 +119,11 @@ class HaloAttn(nn.Module):
qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
qkv_bias (bool) : add bias to q, k, and v projections
avg_down (bool): use average pool downsample instead of strided query blocks
scale_pos_embed (bool): scale the position embedding as well as Q @ K
"""
def __init__(
self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3,
qk_ratio=1.0, qkv_bias=False, avg_down=False):
qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False):
super().__init__()
dim_out = dim_out or dim
assert dim_out % num_heads == 0
@ -135,6 +134,7 @@ class HaloAttn(nn.Module):
self.dim_out_qk = num_heads * self.dim_head_qk
self.dim_out_v = num_heads * self.dim_head_v
self.scale = self.dim_head_qk ** -0.5
self.scale_pos_embed = scale_pos_embed
self.block_size = self.block_size_ds = block_size
self.halo_size = halo_size
self.win_size = block_size + halo_size * 2 # neighbourhood window size
@ -190,8 +190,11 @@ class HaloAttn(nn.Module):
k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
# B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v
attn = (q @ k.transpose(-1, -2)) * self.scale
attn = attn + self.pos_embed(q) # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
if self.scale_pos_embed:
attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale
else:
attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q)
# B * num_heads, num_blocks, block_size ** 2, win_size ** 2
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks

Loading…
Cancel
Save