Halo, bottleneck attn, lambda layer additions and cleanup along w/ experimental model defs

* align interfaces of halo, bottleneck attn and lambda layer
* add qk_ratio to all of above, control q/k dim relative to output dim
* add experimental haloregnetz, and trionet (lambda + halo + bottle) models
pull/910/head
Ross Wightman 3 years ago
parent e0b3a3fab3
commit e2b8d44ff0

@ -66,6 +66,13 @@ default_cfgs = {
'lambda_resnet26rpt_256': _cfg( 'lambda_resnet26rpt_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_a2h_256-482adad8.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_a2h_256-482adad8.pth',
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'haloregnetz_b': _cfg(
url='',
input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),
'trionet50ts_256': _cfg(
url='',
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
} }
@ -232,6 +239,46 @@ model_cfgs = dict(
self_attn_layer='lambda', self_attn_layer='lambda',
self_attn_kwargs=dict(r=None) self_attn_kwargs=dict(r=None)
), ),
# experimental
haloregnetz_b=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
interleave_blocks(types=('bottle', 'self_attn'), every=3, d=12, c=192, s=2, gs=16, br=3),
ByoBlockCfg('self_attn', d=2, c=288, s=2, gs=16, br=3),
),
stem_chs=32,
stem_pool='',
downsample='',
num_features=1536,
act_layer='silu',
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=7, halo_size=2, qk_ratio=0.33)
),
# experimental
trionet50ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
interleave_blocks(
types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
self_attn_layer='lambda', self_attn_kwargs=dict(r=13)),
interleave_blocks(
types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
interleave_blocks(
types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
self_attn_layer='bottleneck', self_attn_kwargs=dict()),
),
stem_chs=64,
stem_type='tiered',
stem_pool='',
act_layer='silu',
),
) )
@ -327,3 +374,17 @@ def lambda_resnet26rpt_256(pretrained=False, **kwargs):
""" """
kwargs.setdefault('img_size', 256) kwargs.setdefault('img_size', 256)
return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs) return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)
@register_model
def haloregnetz_b(pretrained=False, **kwargs):
""" Halo + RegNetZ
"""
return _create_byoanet('haloregnetz_b', pretrained=pretrained, **kwargs)
@register_model
def trionet50ts_256(pretrained=False, **kwargs):
""" HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages
"""
return _create_byoanet('trionet50ts_256', 'trionet50ts', pretrained=pretrained, **kwargs)

@ -1096,18 +1096,16 @@ class SelfAttnBlock(nn.Module):
self.self_attn.reset_parameters() self.self_attn.reset_parameters()
def forward(self, x): def forward(self, x):
shortcut = self.shortcut(x) shortcut = x
x = self.conv1_1x1(x) x = self.conv1_1x1(x)
x = self.conv2_kxk(x) x = self.conv2_kxk(x)
x = self.self_attn(x) x = self.self_attn(x)
x = self.post_attn(x) x = self.post_attn(x)
x = self.conv3_1x1(x) x = self.conv3_1x1(x)
x = self.drop_path(x) x = self.drop_path(x)
if self.shortcut is not None:
x = self.act(x + shortcut) x = x + self.shortcut(shortcut)
return x return self.act(x)
_block_registry = dict( _block_registry = dict(
basic=BasicBlock, basic=BasicBlock,

@ -20,7 +20,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .helpers import to_2tuple from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_ from .weight_init import trunc_normal_
@ -66,10 +66,10 @@ class PosEmbedRel(nn.Module):
self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale) self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale)
def forward(self, q): def forward(self, q):
B, num_heads, HW, _ = q.shape B, HW, _ = q.shape
# relative logits in width dimension. # relative logits in width dimension.
q = q.reshape(B * num_heads, self.height, self.width, -1) q = q.reshape(B, self.height, self.width, -1)
rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
# relative logits in height dimension. # relative logits in height dimension.
@ -77,35 +77,56 @@ class PosEmbedRel(nn.Module):
rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
rel_logits = rel_logits_h + rel_logits_w rel_logits = rel_logits_h + rel_logits_w
rel_logits = rel_logits.reshape(B, num_heads, HW, HW) rel_logits = rel_logits.reshape(B, HW, HW)
return rel_logits return rel_logits
class BottleneckAttn(nn.Module): class BottleneckAttn(nn.Module):
""" Bottleneck Attention """ Bottleneck Attention
Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
The internal dimensions of the attention module are controlled by the interaction of several arguments.
* the output dimension of the module is specified by dim_out, which falls back to input dim if not set
* the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
* the query and key (qk) dimensions are determined by
* num_heads * dim_head if dim_head is not None
* num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None
* as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used
Args:
dim (int): input dimension to the module
dim_out (int): output dimension of the module, same as dim if not set
stride (int): output stride of the module, avg pool used if stride == 2 (default: 1).
num_heads (int): parallel attention heads (default: 4)
dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
qkv_bias (bool): add bias to q, k, and v projections
""" """
def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False): def __init__(
self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None,
qk_ratio=1.0, qkv_bias=False):
super().__init__() super().__init__()
assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
dim_out = dim_out or dim dim_out = dim_out or dim
assert dim_out % num_heads == 0 assert dim_out % num_heads == 0
self.num_heads = num_heads self.num_heads = num_heads
self.dim_out = dim_out self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
self.dim_head = dim_out // num_heads self.dim_head_v = dim_out // self.num_heads
self.scale = self.dim_head ** -0.5 self.dim_out_qk = num_heads * self.dim_head_qk
self.dim_out_v = num_heads * self.dim_head_v
self.scale = self.dim_head_qk ** -0.5
self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias) self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
# NOTE I'm only supporting relative pos embedding for now # NOTE I'm only supporting relative pos embedding for now
self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale) self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in
trunc_normal_(self.pos_embed.height_rel, std=self.scale) trunc_normal_(self.pos_embed.height_rel, std=self.scale)
trunc_normal_(self.pos_embed.width_rel, std=self.scale) trunc_normal_(self.pos_embed.width_rel, std=self.scale)
@ -114,15 +135,20 @@ class BottleneckAttn(nn.Module):
assert H == self.pos_embed.height assert H == self.pos_embed.height
assert W == self.pos_embed.width assert W == self.pos_embed.width
x = self.qkv(x) # B, 3 * num_heads * dim_head, H, W x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2)
q, k, v = torch.split(x, self.num_heads, dim=1) # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v
# So, this is more verbose than if heads were before qkv splits, but throughput is not impacted.
q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
attn = (q @ k.transpose(-1, -2)) * self.scale attn = (q @ k) * self.scale
attn = attn + self.pos_embed(q) # B, num_heads, H * W, H * W attn = attn + self.pos_embed(q) # B * num_heads, H * W, H * W
attn = attn.softmax(dim=-1) attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
out = self.pool(out) out = self.pool(out)
return out return out

@ -22,6 +22,7 @@ import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from .helpers import make_divisible
from .weight_init import trunc_normal_ from .weight_init import trunc_normal_
@ -98,31 +99,62 @@ class HaloAttn(nn.Module):
Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones` Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
- https://arxiv.org/abs/2103.12731 - https://arxiv.org/abs/2103.12731
The internal dimensions of the attention module are controlled by the interaction of several arguments.
* the output dimension of the module is specified by dim_out, which falls back to input dim if not set
* the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
* the query and key (qk) dimensions are determined by
* num_heads * dim_head if dim_head is not None
* num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None
* as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used
Args:
dim (int): input dimension to the module
dim_out (int): output dimension of the module, same as dim if not set
feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda)
stride: output stride of the module, query downscaled if > 1 (default: 1).
num_heads: parallel attention heads (default: 8).
dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
block_size (int): size of blocks. (default: 8)
halo_size (int): size of halo overlap. (default: 3)
qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
qkv_bias (bool) : add bias to q, k, and v projections
avg_down (bool): use average pool downsample instead of strided query blocks
""" """
def __init__( def __init__(
self, dim, dim_out=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3, qkv_bias=False): self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3,
qk_ratio=1.0, qkv_bias=False, avg_down=False):
super().__init__() super().__init__()
dim_out = dim_out or dim dim_out = dim_out or dim
assert dim_out % num_heads == 0 assert dim_out % num_heads == 0
self.stride = stride assert stride in (1, 2)
self.num_heads = num_heads self.num_heads = num_heads
self.dim_head_qk = dim_head or dim_out // num_heads self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
self.dim_head_v = dim_out // self.num_heads self.dim_head_v = dim_out // self.num_heads
self.dim_out_qk = num_heads * self.dim_head_qk self.dim_out_qk = num_heads * self.dim_head_qk
self.dim_out_v = num_heads * self.dim_head_v self.dim_out_v = num_heads * self.dim_head_v
self.block_size = block_size self.scale = self.dim_head_qk ** -0.5
self.block_size = self.block_size_ds = block_size
self.halo_size = halo_size self.halo_size = halo_size
self.win_size = block_size + halo_size * 2 # neighbourhood window size self.win_size = block_size + halo_size * 2 # neighbourhood window size
self.scale = self.dim_head_qk ** -0.5 self.block_stride = 1
use_avg_pool = False
if stride > 1:
use_avg_pool = avg_down or block_size % stride != 0
self.block_stride = 1 if use_avg_pool else stride
self.block_size_ds = self.block_size // self.block_stride
# FIXME not clear if this stride behaviour is what the paper intended # FIXME not clear if this stride behaviour is what the paper intended
# Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
# data in unfolded block form. I haven't wrapped my head around how that'd look. # data in unfolded block form. I haven't wrapped my head around how that'd look.
self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.stride, bias=qkv_bias) self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias)
self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias) self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)
self.pos_embed = PosEmbedRel( self.pos_embed = PosEmbedRel(
block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale) block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)
self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity()
self.reset_parameters() self.reset_parameters()
@ -140,11 +172,12 @@ class HaloAttn(nn.Module):
num_h_blocks = H // self.block_size num_h_blocks = H // self.block_size
num_w_blocks = W // self.block_size num_w_blocks = W // self.block_size
num_blocks = num_h_blocks * num_w_blocks num_blocks = num_h_blocks * num_w_blocks
bs_stride = self.block_size // self.stride
q = self.q(x) q = self.q(x)
# unfold # unfold
q = q.reshape(-1, self.dim_head_qk, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4) q = q.reshape(
-1, self.dim_head_qk,
num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4)
# B, num_heads * dim_head * block_size ** 2, num_blocks # B, num_heads * dim_head * block_size ** 2, num_blocks
q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3) q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
# B * num_heads, num_blocks, block_size ** 2, dim_head # B * num_heads, num_blocks, block_size ** 2, dim_head
@ -163,9 +196,11 @@ class HaloAttn(nn.Module):
out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks
# fold # fold
out = out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks) out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks)
out = out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_out_v, H // self.stride, W // self.stride) out = out.permute(0, 3, 1, 4, 2).contiguous().view(
# B, dim_out, H // stride, W // stride B, self.dim_out_v, H // self.block_stride, W // self.block_stride)
# B, dim_out, H // block_stride, W // block_stride
out = self.pool(out)
return out return out

@ -24,7 +24,7 @@ import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from .helpers import to_2tuple from .helpers import to_2tuple, make_divisible
from .weight_init import trunc_normal_ from .weight_init import trunc_normal_
@ -44,28 +44,46 @@ class LambdaLayer(nn.Module):
- https://arxiv.org/abs/2102.08602 - https://arxiv.org/abs/2102.08602
NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add. NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.
The internal dimensions of the lambda module are controlled via the interaction of several arguments.
* the output dimension of the module is specified by dim_out, which falls back to input dim if not set
* the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
* the query (q) and key (k) dimension are determined by
* dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None
* q = num_heads * dim_head, k = dim_head
* as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set
Args:
dim (int): input dimension to the module
dim_out (int): output dimension of the module, same as dim if not set
feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W
stride (int): output stride of the module, avg pool used if stride == 2
num_heads (int): parallel attention heads.
dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9)
qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
qkv_bias (bool): add bias to q, k, and v projections
""" """
def __init__( def __init__(
self, self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9,
dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False): qk_ratio=1.0, qkv_bias=False):
super().__init__() super().__init__()
self.dim = dim dim_out = dim_out or dim
self.dim_out = dim_out or dim assert dim_out % num_heads == 0, ' should be divided by num_heads'
self.dim_k = dim_head # query depth 'k' self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
self.num_heads = num_heads self.num_heads = num_heads
assert self.dim_out % num_heads == 0, ' should be divided by num_heads' self.dim_v = dim_out // num_heads
self.dim_v = self.dim_out // num_heads # value depth 'v'
self.qkv = nn.Conv2d( self.qkv = nn.Conv2d(
dim, dim,
num_heads * dim_head + dim_head + self.dim_v, num_heads * self.dim_qk + self.dim_qk + self.dim_v,
kernel_size=1, bias=qkv_bias) kernel_size=1, bias=qkv_bias)
self.norm_q = nn.BatchNorm2d(num_heads * dim_head) self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)
self.norm_v = nn.BatchNorm2d(self.dim_v) self.norm_v = nn.BatchNorm2d(self.dim_v)
if r is not None: if r is not None:
# local lambda convolution for pos # local lambda convolution for pos
self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0)) self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0))
self.pos_emb = None self.pos_emb = None
self.rel_pos_indices = None self.rel_pos_indices = None
else: else:
@ -74,7 +92,7 @@ class LambdaLayer(nn.Module):
feat_size = to_2tuple(feat_size) feat_size = to_2tuple(feat_size)
rel_size = [2 * s - 1 for s in feat_size] rel_size = [2 * s - 1 for s in feat_size]
self.conv_lambda = None self.conv_lambda = None
self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_k)) self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk))
self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False) self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
@ -82,9 +100,9 @@ class LambdaLayer(nn.Module):
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
trunc_normal_(self.qkv.weight, std=self.dim ** -0.5) trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in
if self.conv_lambda is not None: if self.conv_lambda is not None:
trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5) trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
if self.pos_emb is not None: if self.pos_emb is not None:
trunc_normal_(self.pos_emb, std=.02) trunc_normal_(self.pos_emb, std=.02)
@ -93,17 +111,17 @@ class LambdaLayer(nn.Module):
M = H * W M = H * W
qkv = self.qkv(x) qkv = self.qkv(x)
q, k, v = torch.split(qkv, [ q, k, v = torch.split(qkv, [
self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1) self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
q = self.norm_q(q).reshape(B, self.num_heads, self.dim_k, M).transpose(-1, -2) # B, num_heads, M, K q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K
v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V
k = F.softmax(k.reshape(B, self.dim_k, M), dim=-1) # B, K, M k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M
content_lam = k @ v # B, K, V content_lam = k @ v # B, K, V
content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V
if self.pos_emb is None: if self.pos_emb is None:
position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K
position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V
else: else:
# FIXME relative pos embedding path not fully verified # FIXME relative pos embedding path not fully verified
pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1) pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)

Loading…
Cancel
Save