From 072155951104230c2b5f3bbfb31acc694ee2fa0a Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 4 May 2021 21:40:39 -0700
Subject: [PATCH] Improved (hopefully) init for SA/SA-like layers used in ByoaNets

---
 timm/models/byoanet.py                | 2 ++
 timm/models/layers/bottleneck_attn.py | 6 ++++++
 timm/models/layers/halo_attn.py       | 9 +++++++++
 timm/models/layers/lambda_layer.py    | 6 ++++++
 timm/models/layers/swin_attn.py       | 6 +++++-
 5 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py
index df88535d..da9e513b 100644
--- a/timm/models/byoanet.py
+++ b/timm/models/byoanet.py
@@ -294,6 +294,8 @@ class SelfAttnBlock(nn.Module):
     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
+        if hasattr(self.self_attn, 'reset_parameters'):
+            self.self_attn.reset_parameters()
 
     def forward(self, x):
         shortcut = self.shortcut(x)
diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py
index 0bb0e27b..9604e8a6 100644
--- a/timm/models/layers/bottleneck_attn.py
+++ b/timm/models/layers/bottleneck_attn.py
@@ -21,6 +21,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from .helpers import to_2tuple
+from .weight_init import trunc_normal_
 
 
 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -101,6 +102,11 @@ class BottleneckAttn(nn.Module):
 
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H == self.pos_embed.height and W == self.pos_embed.width
diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py
index 8452aa94..87cae895 100644
--- a/timm/models/layers/halo_attn.py
+++ b/timm/models/layers/halo_attn.py
@@ -25,6 +25,8 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
+from .weight_init import trunc_normal_
+
 
 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
     """ Compute relative logits along one dimension
@@ -124,6 +126,13 @@ class HaloAttn(nn.Module):
         self.pos_embed = PosEmbedRel(
             block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)
 
+    def reset_parameters(self):
+        std = self.q.weight.shape[1] ** -0.5  # fan-in
+        trunc_normal_(self.q.weight, std=std)
+        trunc_normal_(self.kv.weight, std=std)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H % self.block_size == 0 and W % self.block_size == 0
diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py
index c89982af..2d1027a1 100644
--- a/timm/models/layers/lambda_layer.py
+++ b/timm/models/layers/lambda_layer.py
@@ -24,6 +24,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
+from .weight_init import trunc_normal_
 
 
 class LambdaLayer(nn.Module):
@@ -36,6 +37,7 @@ class LambdaLayer(nn.Module):
             self,
             dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
         super().__init__()
+        self.dim = dim
         self.dim_out = dim_out or dim
         self.dim_k = dim_head  # query depth 'k'
         self.num_heads = num_heads
@@ -55,6 +57,10 @@ class LambdaLayer(nn.Module):
 
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
+        trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
+
     def forward(self, x):
         B, C, H, W = x.shape
         M = H * W
diff --git a/timm/models/layers/swin_attn.py b/timm/models/layers/swin_attn.py
index 46dacb62..02131bbc 100644
--- a/timm/models/layers/swin_attn.py
+++ b/timm/models/layers/swin_attn.py
@@ -107,6 +107,7 @@ class WindowAttention(nn.Module):
         self.relative_position_bias_table = nn.Parameter(
             # 2 * Wh - 1 * 2 * Ww - 1, nH
             torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
+        trunc_normal_(self.relative_position_bias_table, std=.02)
 
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.win_size)
@@ -120,13 +121,16 @@
         relative_coords[:, :, 0] *= 2 * self.win_size - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
         self.register_buffer("relative_position_index", relative_position_index)
-        trunc_normal_(self.relative_position_bias_table, std=.02)
 
         self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
         self.softmax = nn.Softmax(dim=-1)
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
 
+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+
     def forward(self, x):
         B, C, H, W = x.shape
         x = x.permute(0, 2, 3, 1)
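
Usage sketch (illustrative, not part of the patch): the change has each self-attention layer expose a reset_parameters() method that draws its projection weights from a truncated normal whose std scales with fan-in (weight.shape[1] ** -0.5), while the enclosing block's init_weights() dispatches to it via hasattr(). The ToyAttn and ToyBlock modules below are hypothetical stand-ins for the timm layers, and torch.nn.init.trunc_normal_ is used here in place of timm's weight_init helper.

import torch.nn as nn
from torch.nn.init import trunc_normal_  # stand-in for timm's .weight_init.trunc_normal_


class ToyAttn(nn.Module):
    """Hypothetical stand-in for an attention layer such as BottleneckAttn or HaloAttn."""
    def __init__(self, dim, dim_out):
        super().__init__()
        self.qkv = nn.Conv2d(dim, dim_out * 3, kernel_size=1, bias=False)

    def reset_parameters(self):
        # fan-in scaled std, the same form used for the qkv projections in the patch
        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)


class ToyBlock(nn.Module):
    """Hypothetical stand-in for the SelfAttnBlock.init_weights() dispatch."""
    def __init__(self, dim):
        super().__init__()
        self.self_attn = ToyAttn(dim, dim)

    def init_weights(self):
        # mirrors the hasattr() check added to SelfAttnBlock.init_weights()
        if hasattr(self.self_attn, 'reset_parameters'):
            self.self_attn.reset_parameters()


block = ToyBlock(64)
block.init_weights()
print(block.self_attn.qkv.weight.std())  # roughly 64 ** -0.5 = 0.125

The hasattr() check keeps the block agnostic to which attention variant it wraps: layers that define reset_parameters() get the scaled init, while others fall back to their default initialization.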