|
|
|
@ -183,7 +183,9 @@ class HaloAttn(nn.Module):
|
|
|
|
|
# B * num_heads, num_blocks, block_size ** 2, dim_head
|
|
|
|
|
|
|
|
|
|
kv = self.kv(x)
|
|
|
|
|
# generate overlapping windows for kv
|
|
|
|
|
# Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not
|
|
|
|
|
# lowered for PyTorch XLA, so it will be very slow there. See code at bottom of file for an XLA-friendly approach.
|
|
|
|
|
# FIXME figure out how to switch impl between this and conv2d if XLA is being used.
|
|
|
|
|
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
|
|
|
|
|
kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
|
|
|
|
|
B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
|
|
|
|
@ -207,17 +209,24 @@ class HaloAttn(nn.Module):
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Two alternatives for overlapping windows.
|
|
|
|
|
""" Three alternatives for overlapping windows.
|
|
|
|
|
|
|
|
|
|
`.unfold().unfold()` is the same speed as stride tricks, with similar clarity to F.unfold()
|
|
|
|
|
|
|
|
|
|
if self.stride_tricks:
|
|
|
|
|
if is_xla:
|
|
|
|
|
# This code achieves haloing on PyTorch XLA with a reasonable runtime trade-off; it is
|
|
|
|
|
# EXTREMELY slow for backward on a GPU though, so I need a way of selecting based on environment.
|
|
|
|
|
WW = self.win_size ** 2
|
|
|
|
|
pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size)
|
|
|
|
|
kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size)
|
|
|
|
|
elif self.stride_tricks:
|
|
|
|
|
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
|
|
|
|
|
kv = kv.as_strided((
|
|
|
|
|
B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
|
|
|
|
|
stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
|
|
|
|
|
else:
|
|
|
|
|
kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
|
|
|
|
|
|
|
|
|
|
kv = kv.reshape(
|
|
|
|
|
B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)
|
|
|
|
|
"""
|
|
|
|
|