diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py
index 5c7be0d6..056813ef 100644
--- a/timm/models/byoanet.py
+++ b/timm/models/byoanet.py
@@ -62,7 +62,7 @@ default_cfgs = {
     'lambda_resnet50ts': _cfg(
         url='',
         min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
-    'lambda_resnet26rt_256': _cfg(
+    'lambda_resnet26rpt_256': _cfg(
         url='',
         fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
 }
@@ -218,7 +218,7 @@ model_cfgs = dict(
         self_attn_layer='lambda',
         self_attn_kwargs=dict(r=9)
     ),
-    lambda_resnet26rt_256=ByoModelCfg(
+    lambda_resnet26rpt_256=ByoModelCfg(
         blocks=(
             ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
             ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
@@ -321,8 +321,8 @@ def lambda_resnet50ts(pretrained=False, **kwargs):
 
 
 @register_model
-def lambda_resnet26rt_256(pretrained=False, **kwargs):
+def lambda_resnet26rpt_256(pretrained=False, **kwargs):
     """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages.
     """
     kwargs.setdefault('img_size', 256)
-    return _create_byoanet('lambda_resnet26rt_256', pretrained=pretrained, **kwargs)
+    return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)
diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py
index d298c1aa..fd174855 100644
--- a/timm/models/layers/lambda_layer.py
+++ b/timm/models/layers/lambda_layer.py
@@ -24,18 +24,30 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
+from .helpers import to_2tuple
 from .weight_init import trunc_normal_
 
 
+def rel_pos_indices(size):
+    size = to_2tuple(size)
+    pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
+    rel_pos = pos[:, None, :] - pos[:, :, None]
+    rel_pos[0] += size[0] - 1
+    rel_pos[1] += size[1] - 1
+    return rel_pos  # 2, H * W, H * W
+
+
 class LambdaLayer(nn.Module):
-    """Lambda Layer w/ lambda conv position embedding
+    """Lambda Layer
 
     Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
         - https://arxiv.org/abs/2102.08602
+
+    NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.
     """
     def __init__(
             self,
-            dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
+            dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
         super().__init__()
         self.dim = dim
         self.dim_out = dim_out or dim
@@ -43,7 +55,6 @@ class LambdaLayer(nn.Module):
         self.num_heads = num_heads
         assert self.dim_out % num_heads == 0, ' should be divided by num_heads'
         self.dim_v = self.dim_out // num_heads  # value depth 'v'
-        self.r = r  # relative position neighbourhood (lambda conv kernel size)
 
         self.qkv = nn.Conv2d(
             dim,
@@ -52,8 +63,19 @@ class LambdaLayer(nn.Module):
         self.norm_q = nn.BatchNorm2d(num_heads * dim_head)
         self.norm_v = nn.BatchNorm2d(self.dim_v)
 
-        # NOTE currently only supporting the local lambda convolutions for positional
-        self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0))
+        if r is not None:
+            # local lambda convolution for pos
+            self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0))
+            self.pos_emb = None
+            self.rel_pos = None
+        else:
+            # relative pos embedding
+            assert feat_size is not None
+            feat_size = to_2tuple(feat_size)
+            rel_size = [2 * s - 1 for s in feat_size]
+            self.conv_lambda = None
+            self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_k))
+            self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)
 
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
@@ -61,12 +83,14 @@ class LambdaLayer(nn.Module):
 
     def reset_parameters(self):
         trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
-        trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
+        if self.conv_lambda is not None:
+            trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
+        if self.pos_emb is not None:
+            trunc_normal_(self.pos_emb, std=.02)
 
     def forward(self, x):
         B, C, H, W = x.shape
         M = H * W
-
         qkv = self.qkv(x)
         q, k, v = torch.split(qkv, [
             self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1)
@@ -77,10 +101,15 @@ class LambdaLayer(nn.Module):
         content_lam = k @ v  # B, K, V
         content_out = q @ content_lam.unsqueeze(1)  # B, num_heads, M, V
 
-        position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, H, W, V, K
-        position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
+        if self.pos_emb is None:
+            position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v))  # B, H, W, V, K
+            position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3)  # B, 1, M, K, V
+        else:
+            # FIXME relative pos embedding path not fully verified
+            pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
+            position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1)  # B, 1, M, K, V
         position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2)  # B, num_heads, M, V
 
-        out = (content_out + position_out).transpose(3, 1).reshape(B, C, H, W)  # B, C (num_heads * V), H, W
+        out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W)  # B, C (num_heads * V), H, W
         out = self.pool(out)
         return out