|
|
@ -116,8 +116,6 @@ class HaloAttn(nn.Module):
|
|
|
|
self.halo_size = halo_size
|
|
|
|
self.halo_size = halo_size
|
|
|
|
self.win_size = block_size + halo_size * 2 # neighbourhood window size
|
|
|
|
self.win_size = block_size + halo_size * 2 # neighbourhood window size
|
|
|
|
self.scale = self.dim_head ** -0.5
|
|
|
|
self.scale = self.dim_head ** -0.5
|
|
|
|
# stride_tricks hard-coded for now, works well on CPU / GPU, neither unfold or as_strided works on TPU (XLA)
|
|
|
|
|
|
|
|
self.stride_tricks = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# FIXME not clear if this stride behaviour is what the paper intended
|
|
|
|
# FIXME not clear if this stride behaviour is what the paper intended
|
|
|
|
# Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
|
|
|
|
# Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
|
|
|
@ -144,26 +142,28 @@ class HaloAttn(nn.Module):
|
|
|
|
bs_stride = self.block_size // self.stride
|
|
|
|
bs_stride = self.block_size // self.stride
|
|
|
|
|
|
|
|
|
|
|
|
q = self.q(x)
|
|
|
|
q = self.q(x)
|
|
|
|
# q = F.unfold(q, kernel_size=bs_stride, stride=bs_stride) # don't need to use unfold here since no overlap
|
|
|
|
# unfold
|
|
|
|
q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
|
|
|
|
q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
|
|
|
|
# B, num_heads * dim_head * block_size ** 2, num_blocks
|
|
|
|
# B, num_heads * dim_head * block_size ** 2, num_blocks
|
|
|
|
q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
|
|
|
|
q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
|
|
|
|
# B * num_heads, num_blocks, block_size ** 2, dim_head
|
|
|
|
# B * num_heads, num_blocks, block_size ** 2, dim_head
|
|
|
|
|
|
|
|
|
|
|
|
kv = self.kv(x)
|
|
|
|
kv = self.kv(x)
|
|
|
|
|
|
|
|
# generate overlapping windows for kv
|
|
|
|
|
|
|
|
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
|
|
|
|
|
|
|
|
kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
|
|
|
|
|
|
|
|
B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), num_blocks, -1).permute(0, 2, 3, 1)
|
|
|
|
|
|
|
|
# NOTE these two alternatives are equivalent, but above is the best balance of performance and clarity
|
|
|
|
|
|
|
|
# if self.stride_tricks:
|
|
|
|
|
|
|
|
# kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
|
|
|
|
|
|
|
|
# kv = kv.as_strided((
|
|
|
|
|
|
|
|
# B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
|
|
|
|
|
|
|
|
# stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
|
|
|
|
|
|
|
|
# else:
|
|
|
|
|
|
|
|
# kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
|
|
|
|
|
|
|
|
# kv = kv.reshape(
|
|
|
|
|
|
|
|
# B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
|
|
|
|
|
|
|
|
|
|
|
|
# generate overlapping windows using either stride tricks (as_strided) or unfold
|
|
|
|
|
|
|
|
if self.stride_tricks:
|
|
|
|
|
|
|
|
# this is much faster
|
|
|
|
|
|
|
|
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
|
|
|
|
|
|
|
|
kv = kv.as_strided((
|
|
|
|
|
|
|
|
B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
|
|
|
|
|
|
|
|
stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
kv = kv.reshape(
|
|
|
|
|
|
|
|
B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
|
|
|
|
|
|
|
|
k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
|
|
|
|
k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
|
|
|
|
# B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads
|
|
|
|
# B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads
|
|
|
|
|
|
|
|
|
|
|
@ -173,10 +173,7 @@ class HaloAttn(nn.Module):
|
|
|
|
attn_out = attn_logits.softmax(dim=-1)
|
|
|
|
attn_out = attn_logits.softmax(dim=-1)
|
|
|
|
attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
|
|
|
|
attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
|
|
|
|
|
|
|
|
|
|
|
|
# F.fold can be replaced by reshape + permute, slightly faster
|
|
|
|
# fold
|
|
|
|
# attn_out = F.fold(
|
|
|
|
|
|
|
|
# attn_out.reshape(B, -1, num_blocks),
|
|
|
|
|
|
|
|
# (H // self.stride, W // self.stride), kernel_size=bs_stride, stride=bs_stride)
|
|
|
|
|
|
|
|
attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
|
|
|
|
attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
|
|
|
|
attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
|
|
|
|
attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
|
|
|
|
# B, dim_out, H // stride, W // stride
|
|
|
|
# B, dim_out, H // stride, W // stride
|
|
|
|