|
|
@ -74,9 +74,8 @@ class PosEmbedRel(nn.Module):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
self.block_size = block_size
|
|
|
|
self.block_size = block_size
|
|
|
|
self.dim_head = dim_head
|
|
|
|
self.dim_head = dim_head
|
|
|
|
self.scale = scale
|
|
|
|
self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
|
|
|
|
self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)
|
|
|
|
self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
|
|
|
|
self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, q):
|
|
|
|
def forward(self, q):
|
|
|
|
B, BB, HW, _ = q.shape
|
|
|
|
B, BB, HW, _ = q.shape
|
|
|
@ -120,11 +119,11 @@ class HaloAttn(nn.Module):
|
|
|
|
qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
|
|
|
|
qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
|
|
|
|
qkv_bias (bool) : add bias to q, k, and v projections
|
|
|
|
qkv_bias (bool) : add bias to q, k, and v projections
|
|
|
|
avg_down (bool): use average pool downsample instead of strided query blocks
|
|
|
|
avg_down (bool): use average pool downsample instead of strided query blocks
|
|
|
|
|
|
|
|
scale_pos_embed (bool): scale the position embedding as well as Q @ K
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3,
|
|
|
|
self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3,
|
|
|
|
qk_ratio=1.0, qkv_bias=False, avg_down=False):
|
|
|
|
qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
dim_out = dim_out or dim
|
|
|
|
dim_out = dim_out or dim
|
|
|
|
assert dim_out % num_heads == 0
|
|
|
|
assert dim_out % num_heads == 0
|
|
|
@ -135,6 +134,7 @@ class HaloAttn(nn.Module):
|
|
|
|
self.dim_out_qk = num_heads * self.dim_head_qk
|
|
|
|
self.dim_out_qk = num_heads * self.dim_head_qk
|
|
|
|
self.dim_out_v = num_heads * self.dim_head_v
|
|
|
|
self.dim_out_v = num_heads * self.dim_head_v
|
|
|
|
self.scale = self.dim_head_qk ** -0.5
|
|
|
|
self.scale = self.dim_head_qk ** -0.5
|
|
|
|
|
|
|
|
self.scale_pos_embed = scale_pos_embed
|
|
|
|
self.block_size = self.block_size_ds = block_size
|
|
|
|
self.block_size = self.block_size_ds = block_size
|
|
|
|
self.halo_size = halo_size
|
|
|
|
self.halo_size = halo_size
|
|
|
|
self.win_size = block_size + halo_size * 2 # neighbourhood window size
|
|
|
|
self.win_size = block_size + halo_size * 2 # neighbourhood window size
|
|
|
@ -190,8 +190,11 @@ class HaloAttn(nn.Module):
|
|
|
|
k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
|
|
|
|
k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
|
|
|
|
# B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v
|
|
|
|
# B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v
|
|
|
|
|
|
|
|
|
|
|
|
attn = (q @ k.transpose(-1, -2)) * self.scale
|
|
|
|
if self.scale_pos_embed:
|
|
|
|
attn = attn + self.pos_embed(q) # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
|
|
|
|
attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q)
|
|
|
|
|
|
|
|
# B * num_heads, num_blocks, block_size ** 2, win_size ** 2
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks
|
|
|
|
out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks
|
|
|
|