diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py index 15df62ae..f55fd989 100644 --- a/timm/models/layers/bottleneck_attn.py +++ b/timm/models/layers/bottleneck_attn.py @@ -61,9 +61,8 @@ class PosEmbedRel(nn.Module): super().__init__() self.height, self.width = to_2tuple(feat_size) self.dim_head = dim_head - self.scale = scale - self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale) - self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale) + self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale) + self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale) def forward(self, q): B, HW, _ = q.shape @@ -101,10 +100,11 @@ class BottleneckAttn(nn.Module): dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) qkv_bias (bool): add bias to q, k, and v projections + scale_pos_embed (bool): scale the position embedding as well as Q @ K """ def __init__( self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None, - qk_ratio=1.0, qkv_bias=False): + qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False): super().__init__() assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' dim_out = dim_out or dim @@ -115,6 +115,7 @@ class BottleneckAttn(nn.Module): self.dim_out_qk = num_heads * self.dim_head_qk self.dim_out_v = num_heads * self.dim_head_v self.scale = self.dim_head_qk ** -0.5 + self.scale_pos_embed = scale_pos_embed self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias) @@ -144,8 +145,10 @@ class BottleneckAttn(nn.Module): k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2) - attn = (q @ k) * self.scale - attn = attn + self.pos_embed(q) # B * num_heads, H * W, H * W + if self.scale_pos_embed: + attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W + else: + attn = (q @ k) * self.scale + self.pos_embed(q) attn = attn.softmax(dim=-1) out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 05fb1f6a..846c12ff 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -74,9 +74,8 @@ class PosEmbedRel(nn.Module): super().__init__() self.block_size = block_size self.dim_head = dim_head - self.scale = scale - self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale) - self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * self.scale) + self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale) + self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale) def forward(self, q): B, BB, HW, _ = q.shape @@ -120,11 +119,11 @@ class HaloAttn(nn.Module): qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0) qkv_bias (bool) : add bias to q, k, and v projections avg_down (bool): use average pool downsample instead of strided query blocks - + scale_pos_embed (bool): scale the position embedding as well as Q @ K """ def __init__( self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, - qk_ratio=1.0, qkv_bias=False, avg_down=False): + qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False): super().__init__() dim_out = dim_out or dim assert dim_out % num_heads == 0 @@ -135,6 +134,7 @@ class HaloAttn(nn.Module): self.dim_out_qk = num_heads * self.dim_head_qk self.dim_out_v = num_heads * self.dim_head_v self.scale = self.dim_head_qk ** -0.5 + self.scale_pos_embed = scale_pos_embed self.block_size = self.block_size_ds = block_size self.halo_size = halo_size self.win_size = block_size + halo_size * 2 # neighbourhood window size @@ -190,8 +190,11 @@ class HaloAttn(nn.Module): k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1) # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v - attn = (q @ k.transpose(-1, -2)) * self.scale - attn = attn + self.pos_embed(q) # B * num_heads, num_blocks, block_size ** 2, win_size ** 2 + if self.scale_pos_embed: + attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale + else: + attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q) + # B * num_heads, num_blocks, block_size ** 2, win_size ** 2 attn = attn.softmax(dim=-1) out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks