diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 6304ae0d..173d2060 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -12,10 +12,7 @@ Year = {2021}, Status: This impl is a WIP, there is no official ref impl and some details in paper weren't clear to me. - -Trying to match the 'H1' variant in the paper, my parameter counts are 2M less and the model -is extremely slow. Something isn't right. However, the models do appear to train and experimental -variants with attn in C4 and/or C5 stages are tolerable speed. +The attention mechanism works but it's slow as implemented. Hacked together by / Copyright 2021 Ross Wightman """ @@ -163,7 +160,6 @@ class HaloAttn(nn.Module): # kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) # kv = kv.reshape( # B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3) - k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1) # B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads