diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 044c5dad..6304ae0d 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -116,8 +116,6 @@ class HaloAttn(nn.Module): self.halo_size = halo_size self.win_size = block_size + halo_size * 2 # neighbourhood window size self.scale = self.dim_head ** -0.5 - # stride_tricks hard-coded for now, works well on CPU / GPU, neither unfold or as_strided works on TPU (XLA) - self.stride_tricks = True # FIXME not clear if this stride behaviour is what the paper intended # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving @@ -144,26 +142,28 @@ class HaloAttn(nn.Module): bs_stride = self.block_size // self.stride q = self.q(x) - # q = F.unfold(q, kernel_size=bs_stride, stride=bs_stride) # don't need to use unfold here since no overlap + # unfold q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4) # B, num_heads * dim_head * block_size ** 2, num_blocks q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3) # B * num_heads, num_blocks, block_size ** 2, dim_head kv = self.kv(x) + # generate overlapping windows for kv + kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]) + kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape( + B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), num_blocks, -1).permute(0, 2, 3, 1) + # NOTE these two alternatives are equivalent, but above is the best balance of performance and clarity + # if self.stride_tricks: + # kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() + # kv = kv.as_strided(( + # B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), + # stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) + # else: + # kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + # kv = kv.reshape( + # B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3) - # generate overlapping windows using either stride tricks (as_strided) or unfold - if self.stride_tricks: - # this is much faster - kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() - kv = kv.as_strided(( - B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), - stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) - else: - kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) - - kv = kv.reshape( - B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3) k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1) # B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads @@ -173,10 +173,7 @@ class HaloAttn(nn.Module): attn_out = attn_logits.softmax(dim=-1) attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks - # F.fold can be replaced by reshape + permute, slightly faster - # attn_out = F.fold( - # attn_out.reshape(B, -1, num_blocks), - # (H // self.stride, W // self.stride), kernel_size=bs_stride, stride=bs_stride) + # fold attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks) attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride) # B, dim_out, H // stride, W // stride