Update HaloAttn comment

pull/821/head
Ross Wightman 3 years ago
parent 3b9032ea48
commit 492c0a4e20

@@ -12,10 +12,7 @@ Year = {2021},
Status:
This impl is a WIP, there is no official ref impl and some details in the paper weren't clear to me.
Trying to match the 'H1' variant in the paper, my parameter counts are 2M lower and the model
is extremely slow. Something isn't right. However, the models do appear to train, and experimental
variants with attn in the C4 and/or C5 stages run at a tolerable speed.
The attention mechanism works but it's slow as implemented.
Hacked together by / Copyright 2021 Ross Wightman
"""
@@ -163,7 +160,6 @@ class HaloAttn(nn.Module):
# kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
# kv = kv.reshape(
# B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
# B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads
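
For reference, a minimal, self-contained sketch of the commented-out unfold path above, using made-up example sizes (block_size=4, halo_size=1, dim_head=16, dim_v=32, a single head, an 8x8 feature map) rather than the model's real config. It only illustrates how F.unfold pulls out one halo-padded win_size x win_size window per block and how the final split yields k and v:

import torch
import torch.nn.functional as F

# Example sizes for illustration only (not the H1 config).
B, dim_head, dim_v = 2, 16, 32
block_size, halo_size = 4, 1
win_size = block_size + 2 * halo_size                 # each block attends to a halo-padded window
H = W = 8
num_blocks = (H // block_size) * (W // block_size)    # 4 blocks for an 8x8 map

kv = torch.randn(B, dim_head + dim_v, H, W)           # packed key/value feature map

# One overlapping win_size x win_size patch per block.
kv = F.unfold(kv, kernel_size=win_size, stride=block_size, padding=halo_size)
# -> (B, (dim_head + dim_v) * win_size**2, num_blocks)
kv = kv.reshape(B, dim_head + dim_v, win_size * win_size, num_blocks).permute(0, 3, 2, 1)
# -> (B, num_blocks, win_size**2, dim_head + dim_v), i.e. channels last

k, v = torch.split(kv, [dim_head, dim_v], dim=-1)
print(k.shape, v.shape)  # torch.Size([2, 4, 36, 16]) torch.Size([2, 4, 36, 32])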
