@@ -12,10 +12,7 @@ Year = {2021},
Status:
This impl is a WIP; there is no official ref impl, and some details in the paper weren't clear to me.

Trying to match the 'H1' variant in the paper, my parameter count is 2M lower and the model
is extremely slow. Something isn't right. However, the models do appear to train, and experimental
variants with attn in C4 and/or C5 stages run at tolerable speed.

The attention mechanism works, but it's slow as implemented.

Hacked together by / Copyright 2021 Ross Wightman
"""
@@ -163,7 +160,6 @@ class HaloAttn(nn.Module):
# kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
# kv = kv.reshape(
# B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
# B * num_heads, num_blocks, win_size ** 2, dim_head or dim_v // num_heads
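For the shapes involved, here is a hypothetical standalone version of the commented-out F.unfold path above; the sizes are made up and stand in for self.num_heads, self.dim_head, self.dim_v, self.block_size and self.halo_size.

import torch
import torch.nn.functional as F

B, H, W = 2, 8, 8
num_heads, dim_head, dim_v = 4, 8, 16         # per-head qk dim and total value dim
block_size, halo_size = 4, 1
win_size = block_size + 2 * halo_size
num_blocks = (H // block_size) * (W // block_size)

# packed key/value feature map from the kv projection: per head, dim_head key
# channels plus dim_v // num_heads value channels
kv = torch.randn(B, num_heads * (dim_head + dim_v // num_heads), H, W)

# one zero-padded win_size x win_size window per query block
kv = F.unfold(kv, kernel_size=win_size, stride=block_size, padding=halo_size)
kv = kv.reshape(B * num_heads, dim_head + dim_v // num_heads, -1, num_blocks).transpose(1, 3)
# B * num_heads, num_blocks, win_size ** 2, dim_head + dim_v // num_heads
k, v = torch.split(kv, [dim_head, dim_v // num_heads], dim=-1)
print(k.shape, v.shape)  # torch.Size([8, 4, 36, 8]) torch.Size([8, 4, 36, 4])

With kernel_size=win_size, stride=block_size and padding=halo_size, a single F.unfold call produces the overlapping, zero-padded halo windows, one per query block, which is what the commented-out alternative is aiming for.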