Add relative pos embed option to LambdaLayer, fix last transpose/reshape.

pull/880/head
Ross Wightman 3 years ago
parent d657e2cc0b
commit b49630a138

@@ -62,7 +62,7 @@ default_cfgs = {
     'lambda_resnet50ts': _cfg(
         url='',
         min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
-    'lambda_resnet26rt_256': _cfg(
+    'lambda_resnet26rpt_256': _cfg(
         url='',
         fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
 }
@@ -218,7 +218,7 @@ model_cfgs = dict(
         self_attn_layer='lambda',
         self_attn_kwargs=dict(r=9)
     ),
-    lambda_resnet26rt_256=ByoModelCfg(
+    lambda_resnet26rpt_256=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
             ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
@@ -321,8 +321,8 @@ def lambda_resnet50ts(pretrained=False, **kwargs):
 @register_model
-def lambda_resnet26rt_256(pretrained=False, **kwargs):
+def lambda_resnet26rpt_256(pretrained=False, **kwargs):
     """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages.
     """
     kwargs.setdefault('img_size', 256)
-    return _create_byoanet('lambda_resnet26rt_256', pretrained=pretrained, **kwargs)
+    return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)
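For context on the rename above: the model registers under the new name, so it is created the usual way. A minimal sanity-check sketch, not part of this commit; it assumes a local timm checkout containing these changes, and since the weight URLs are still empty only pretrained=False makes sense:

    import torch
    import timm  # assumes a local checkout with this commit applied

    model = timm.create_model('lambda_resnet26rpt_256', pretrained=False)
    model.eval()

    # fixed_input_size=True in the cfg, so feed the 256x256 resolution it was built for
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # torch.Size([1, 1000]) with the default classifier head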

@@ -24,18 +24,30 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
+from .helpers import to_2tuple
 from .weight_init import trunc_normal_
 
 
+def rel_pos_indices(size):
+    size = to_2tuple(size)
+    pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
+    rel_pos = pos[:, None, :] - pos[:, :, None]
+    rel_pos[0] += size[0] - 1
+    rel_pos[1] += size[1] - 1
+    return rel_pos  # 2, H * W, H * W
+
+
 class LambdaLayer(nn.Module):
-    """Lambda Layer w/ lambda conv position embedding
+    """Lambda Layer
 
     Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
         - https://arxiv.org/abs/2102.08602
+
+    NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.
     """
 
     def __init__(
             self,
-            dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
+            dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
         super().__init__()
         self.dim = dim
         self.dim_out = dim_out or dim
@@ -43,7 +55,6 @@ class LambdaLayer(nn.Module):
         self.num_heads = num_heads
         assert self.dim_out % num_heads == 0, ' should be divided by num_heads'
         self.dim_v = self.dim_out // num_heads  # value depth 'v'
-        self.r = r  # relative position neighbourhood (lambda conv kernel size)
 
         self.qkv = nn.Conv2d(
             dim,
@@ -52,8 +63,19 @@
         self.norm_q = nn.BatchNorm2d(num_heads * dim_head)
         self.norm_v = nn.BatchNorm2d(self.dim_v)
 
-        # NOTE currently only supporting the local lambda convolutions for positional
-        self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0))
+        if r is not None:
+            # local lambda convolution for pos
+            self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0))
+            self.pos_emb = None
+            self.rel_pos = None
+        else:
+            # relative pos embedding
+            assert feat_size is not None
+            feat_size = to_2tuple(feat_size)
+            rel_size = [2 * s - 1 for s in feat_size]
+            self.conv_lambda = None
+            self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_k))
+            self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)
 
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
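The constructor now has two mutually exclusive position paths: keep r for the original local lambda convolution (input-size agnostic), or pass r=None together with a fixed feat_size for the learned relative position embedding. A hedged construction sketch; the import path is an assumption, adjust it to wherever the layer lives in your checkout:

    import torch
    from timm.models.layers import LambdaLayer  # assumed export path

    # original path: position term from a local 3d conv, works at any input size
    conv_pos = LambdaLayer(dim=256, num_heads=4, dim_head=16, r=9)

    # new path: learned rel pos embedding, tied to a fixed 16x16 feature map
    rel_pos = LambdaLayer(dim=256, num_heads=4, dim_head=16, r=None, feat_size=(16, 16))

    x = torch.randn(2, 256, 16, 16)  # spatial size must match feat_size for the rel pos path
    print(conv_pos(x).shape, rel_pos(x).shape)  # both torch.Size([2, 256, 16, 16])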
@@ -61,12 +83,14 @@
     def reset_parameters(self):
         trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
-        trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
+        if self.conv_lambda is not None:
+            trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
+        if self.pos_emb is not None:
+            trunc_normal_(self.pos_emb, std=.02)
 
     def forward(self, x):
         B, C, H, W = x.shape
         M = H * W
 
         qkv = self.qkv(x)
         q, k, v = torch.split(qkv, [
             self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1)
@@ -77,10 +101,15 @@ class LambdaLayer(nn.Module):
         content_lam = k @ v  # B, K, V
         content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V
 
-        position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, H, W, V, K
-        position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
+        if self.pos_emb is None:
+            position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, H, W, V, K
+            position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
+        else:
+            # FIXME relative pos embedding path not fully verified
+            pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
+            position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1)  # B, 1, M, K, V
         position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2)  # B, num_heads, M, V
-        out = (content_out + position_out).transpose(3, 1).reshape(B, C, H, W)  # B, C (num_heads * V), H, W
+        out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W)  # B, C (num_heads * V), H, W
         out = self.pool(out)
         return out
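On the "fix last transpose/reshape" part of the commit title: content_out + position_out has shape (B, num_heads, M, V), and the channel layout expected downstream groups channels as num_heads * V, so the tensor must be reordered to (B, num_heads, V, M) before reshaping to (B, C, H, W). transpose(-1, -2) does exactly that, while the old transpose(3, 1) produced (B, V, M, num_heads) and scrambled the channel grouping. A standalone shape check, illustration only:

    import torch

    B, num_heads, H, W, V = 2, 4, 8, 8, 16
    M = H * W
    out = torch.randn(B, num_heads, M, V)

    fixed = out.transpose(-1, -2).reshape(B, num_heads * V, H, W)  # (B, heads, V, M) -> (B, C, H, W)
    old = out.transpose(3, 1).reshape(B, num_heads * V, H, W)      # (B, V, M, heads): same shape, wrong order

    print(fixed.shape == old.shape)  # True
    print(torch.equal(fixed, old))   # False: elements land in different channels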
