diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py
index 056813ef..f58b724c 100644
--- a/timm/models/byoanet.py
+++ b/timm/models/byoanet.py
@@ -3,7 +3,7 @@
 A flexible network w/ dataclass based config for stacking NN blocks including
 self-attention (or similar) layers.
 
-Currently used to implement experimential variants of:
+Currently used to implement experimental variants of:
  * Bottleneck Transformers
  * Lambda ResNets
  * HaloNets
@@ -46,15 +46,16 @@ default_cfgs = {
     'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
     'halonet26t': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_256-9b4bf0b3.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'sehalonet33ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
         input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'halonet50ts': _cfg(
-        url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_256_ra3-f07eab9f.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
     'eca_halonext26ts': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_256-1e55880b.pth',
-        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
 
     'lambda_resnet26t': _cfg(
         url='',
diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py
index bf6af675..61859f9c 100644
--- a/timm/models/layers/bottleneck_attn.py
+++ b/timm/models/layers/bottleneck_attn.py
@@ -118,12 +118,12 @@ class BottleneckAttn(nn.Module):
         x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2)
         q, k, v = torch.split(x, self.num_heads, dim=1)
 
-        attn_logits = (q @ k.transpose(-1, -2)) * self.scale
-        attn_logits = attn_logits + self.pos_embed(q)  # B, num_heads, H * W, H * W
+        attn = (q @ k.transpose(-1, -2)) * self.scale
+        attn = attn + self.pos_embed(q)  # B, num_heads, H * W, H * W
+        attn = attn.softmax(dim=-1)
 
-        attn_out = attn_logits.softmax(dim=-1)
-        attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W)  # B, dim_out, H, W
-        attn_out = self.pool(attn_out)
-        return attn_out
+        out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W)  # B, dim_out, H, W
+        out = self.pool(out)
+        return out
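
The bottleneck_attn.py hunk is a pure rename (attn_logits/attn_out -> attn/out) plus hoisting the softmax up next
to the logits it normalizes; the computation is numerically unchanged. For reference, a minimal shape-flow sketch
of the refactored forward, with illustrative sizes (2 heads, dim_head=4, an 8x8 feature map) that are not taken
from the patch, and a zero tensor standing in for the relative position term:

    import torch

    B, num_heads, dim_head, H, W = 1, 2, 4, 8, 8
    q = torch.randn(B, num_heads, H * W, dim_head)
    k = torch.randn(B, num_heads, H * W, dim_head)
    v = torch.randn(B, num_heads, H * W, dim_head)

    attn = (q @ k.transpose(-1, -2)) * dim_head ** -0.5  # B, num_heads, H * W, H * W
    attn = attn + torch.zeros_like(attn)                 # stand-in for self.pos_embed(q)
    attn = attn.softmax(dim=-1)
    out = (attn @ v).transpose(-1, -2).reshape(B, num_heads * dim_head, H, W)
    assert out.shape == (1, 8, 8, 8)                     # B, dim_out, H, W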
diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py
index d298fc0b..034c66a8 100644
--- a/timm/models/layers/halo_attn.py
+++ b/timm/models/layers/halo_attn.py
@@ -106,22 +106,23 @@ class HaloAttn(nn.Module):
         assert dim_out % num_heads == 0
         self.stride = stride
         self.num_heads = num_heads
-        self.dim_head = dim_head or dim // num_heads
-        self.dim_qk = num_heads * self.dim_head
-        self.dim_v = dim_out
+        self.dim_head_qk = dim_head or dim_out // num_heads
+        self.dim_head_v = dim_out // self.num_heads
+        self.dim_out_qk = num_heads * self.dim_head_qk
+        self.dim_out_v = num_heads * self.dim_head_v
         self.block_size = block_size
         self.halo_size = halo_size
         self.win_size = block_size + halo_size * 2  # neighbourhood window size
-        self.scale = self.dim_head ** -0.5
+        self.scale = self.dim_head_qk ** -0.5
 
         # FIXME not clear if this stride behaviour is what the paper intended
         # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
         # data in unfolded block form. I haven't wrapped my head around how that'd look.
-        self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias)
-        self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias)
+        self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.stride, bias=qkv_bias)
+        self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)
 
         self.pos_embed = PosEmbedRel(
-            block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)
+            block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)
 
         self.reset_parameters()
 
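
The renamed fields above decouple the query/key head width from the value head width (which is always
dim_out // num_heads), and they also change the qk default from dim // num_heads to dim_out // num_heads.
A worked example of the new bookkeeping, using illustrative numbers not taken from the patch:

    num_heads, dim, dim_out, dim_head = 8, 256, 256, 16

    dim_head_qk = dim_head or dim_out // num_heads  # 16: narrow qk heads when dim_head is given
    dim_head_v = dim_out // num_heads               # 32: v width is always tied to dim_out
    dim_out_qk = num_heads * dim_head_qk            # 128 channels from self.q (and the k half of self.kv)
    dim_out_v = num_heads * dim_head_v              # 256 channels for the v half; equals dim_out
    assert (dim_out_qk, dim_out_v) == (128, 256)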
@@ -143,37 +144,42 @@ class HaloAttn(nn.Module):
 
         q = self.q(x)
         # unfold
-        q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
+        q = q.reshape(-1, self.dim_head_qk, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
         # B, num_heads * dim_head * block_size ** 2, num_blocks
-        q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
+        q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
         # B * num_heads, num_blocks, block_size ** 2, dim_head
 
         kv = self.kv(x)
         # generate overlapping windows for kv
         kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
         kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
-            B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), num_blocks, -1).permute(0, 2, 3, 1)
-        # NOTE these two alternatives are equivalent, but above is the best balance of performance and clarity
-        # if self.stride_tricks:
-        #     kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
-        #     kv = kv.as_strided((
-        #         B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
-        #         stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
-        # else:
-        #     kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
-        #     kv = kv.reshape(
-        #         B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
-        k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
-        # B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads
-
-        attn_logits = (q @ k.transpose(-1, -2)) * self.scale  # FIXME should usual attn scale be applied?
-        attn_logits = attn_logits + self.pos_embed(q)  # B * num_heads, block_size ** 2, win_size ** 2
-
-        attn_out = attn_logits.softmax(dim=-1)
-        attn_out = (attn_out @ v).transpose(1, 3)  # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
+            B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
+        k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
+        # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v
+        attn = (q @ k.transpose(-1, -2)) * self.scale
+        attn = attn + self.pos_embed(q)  # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
+        attn = attn.softmax(dim=-1)
+
+        out = (attn @ v).transpose(1, 3)  # B * num_heads, dim_head_v, block_size ** 2, num_blocks
         # fold
-        attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
-        attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
+        out = out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
+        out = out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_out_v, H // self.stride, W // self.stride)
         # B, dim_out, H // stride, W // stride
-        return attn_out
+        return out
+
+
+""" Two alternatives for overlapping windows.
+
+`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold()
+
+    if self.stride_tricks:
+        kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
+        kv = kv.as_strided((
+            B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
+            stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
+    else:
+        kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
+        kv = kv.reshape(
+            B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)
+"""
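
The equivalence claimed in the trailing docstring is easy to verify outside the module. A self-contained sketch
(illustrative sizes, not from the patch) checking that the double .unfold(), the as_strided stride tricks, and
F.unfold all gather identical overlapping kv windows:

    import torch
    import torch.nn.functional as F

    B, C, H, W = 2, 8, 16, 16
    block, halo = 4, 2
    win = block + halo * 2
    num_h_blocks, num_w_blocks = H // block, W // block

    x = torch.randn(B, C, H, W)
    xp = F.pad(x, [halo, halo, halo, halo]).contiguous()

    # 1) double .unfold(), as used in the forward pass
    a = xp.unfold(2, win, block).unfold(3, win, block)  # B, C, num_h_blocks, num_w_blocks, win, win

    # 2) as_strided with manually computed strides over the padded tensor
    b = xp.as_strided(
        (B, C, win, win, num_h_blocks, num_w_blocks),
        stride=(xp.stride(0), xp.stride(1), xp.shape[-1], 1, block * xp.shape[-1], block))
    b = b.permute(0, 1, 4, 5, 2, 3)

    # 3) F.unfold on the unpadded input; its channel dim is ordered (C, win, win)
    c = F.unfold(x, kernel_size=win, stride=block, padding=halo)
    c = c.reshape(B, C, win, win, num_h_blocks, num_w_blocks).permute(0, 1, 4, 5, 2, 3)

    assert torch.equal(a, b) and torch.equal(a, c)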