Improved (hopefully) init for SA/SA-like layers used in ByoaNets

Branch: pull/612/head
Author: Ross Wightman (4 years ago)
Parent: d5473c17f7
Commit: 0721559511

@@ -294,6 +294,8 @@ class SelfAttnBlock(nn.Module):
     def init_weights(self, zero_init_last_bn=False):
         if zero_init_last_bn:
             nn.init.zeros_(self.conv3_1x1.bn.weight)
+        if hasattr(self.self_attn, 'reset_parameters'):
+            self.self_attn.reset_parameters()

     def forward(self, x):
         shortcut = self.shortcut(x)
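
The hunk above wires the block's init_weights() into the attention layer's own reset_parameters() hook. A minimal sketch of that delegation pattern (toy class, not timm's actual SelfAttnBlock):

import torch.nn as nn

class ToySelfAttnBlock(nn.Module):
    """Illustrative stand-in for a ByoaNet block (not the real timm class)."""
    def __init__(self, attn_layer: nn.Module, channels: int = 256):
        super().__init__()
        self.self_attn = attn_layer
        self.bn = nn.BatchNorm2d(channels)   # stands in for conv3_1x1.bn

    def init_weights(self, zero_init_last_bn=False):
        if zero_init_last_bn:
            nn.init.zeros_(self.bn.weight)   # zero-init the last BN, as before
        # new in this commit: defer attention-specific init to the layer itself
        if hasattr(self.self_attn, 'reset_parameters'):
            self.self_attn.reset_parameters()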

@@ -21,6 +21,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from .helpers import to_2tuple
+from .weight_init import trunc_normal_


 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
@@ -101,6 +102,11 @@ class BottleneckAttn(nn.Module):
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H == self.pos_embed.height and W == self.pos_embed.width
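
In BottleneckAttn the qkv projection is a 1x1 Conv2d, so its weight has shape (out_channels, in_channels, 1, 1) and qkv.weight.shape[1] is the fan-in; the relative position embeddings reuse self.scale (the usual dim_head ** -0.5 attention scaling). A quick sanity check of the std choice, using arbitrary example channel counts and PyTorch's nn.init.trunc_normal_ (equivalent to timm's helper for these values):

import torch.nn as nn
from torch.nn.init import trunc_normal_

qkv = nn.Conv2d(256, 3 * 512, kernel_size=1, bias=False)
fan_in = qkv.weight.shape[1]                    # 256: (out, in, 1, 1) for a 1x1 conv
trunc_normal_(qkv.weight, std=fan_in ** -0.5)
print(fan_in ** -0.5, qkv.weight.std().item())  # both ≈ 0.0625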

@@ -25,6 +25,8 @@ import torch
 from torch import nn
 import torch.nn.functional as F

+from .weight_init import trunc_normal_
+

 def rel_logits_1d(q, rel_k, permute_mask: List[int]):
     """ Compute relative logits along one dimension
@@ -124,6 +126,13 @@ class HaloAttn(nn.Module):
         self.pos_embed = PosEmbedRel(
             block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)

+    def reset_parameters(self):
+        std = self.q.weight.shape[1] ** -0.5  # fan-in
+        trunc_normal_(self.q.weight, std=std)
+        trunc_normal_(self.kv.weight, std=std)
+        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+        trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
     def forward(self, x):
         B, C, H, W = x.shape
         assert H % self.block_size == 0 and W % self.block_size == 0
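
HaloAttn applies the same fan-in rule to its separate q and kv 1x1 projections. The reasoning behind std = fan_in ** -0.5 is variance preservation: with roughly unit-variance inputs, a projection whose weights have that std keeps the output variance near one. A small numeric illustration (plain torch, example sizes only):

import torch

torch.manual_seed(0)
fan_in, fan_out = 512, 1536
w = torch.randn(fan_out, fan_in) * fan_in ** -0.5   # weights with std = fan_in ** -0.5
x = torch.randn(8192, fan_in)                       # ~unit-variance inputs
y = x @ w.t()
print(x.var().item(), y.var().item())               # both ≈ 1.0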

@@ -24,6 +24,7 @@ import torch
 from torch import nn
 import torch.nn.functional as F
+from .weight_init import trunc_normal_


 class LambdaLayer(nn.Module):
@@ -36,6 +37,7 @@ class LambdaLayer(nn.Module):
             self,
             dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
         super().__init__()
+        self.dim = dim
         self.dim_out = dim_out or dim
         self.dim_k = dim_head  # query depth 'k'
         self.num_heads = num_heads
@@ -55,6 +57,10 @@ class LambdaLayer(nn.Module):
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
+        trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
+
     def forward(self, x):
         B, C, H, W = x.shape
         M = H * W
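
LambdaLayer now stashes self.dim in __init__ so reset_parameters() can derive its std later without reading it off a weight shape; the lambda conv uses dim_k as its fan-in. The same pattern in a self-contained toy module (hypothetical names, shapes only roughly mirroring the real layer):

import torch.nn as nn
from torch.nn.init import trunc_normal_

class ToyLambdaLike(nn.Module):
    def __init__(self, dim=256, num_heads=4, dim_k=16, dim_v=64, r=7):
        super().__init__()
        self.dim = dim       # kept so reset_parameters() can recompute its std
        self.dim_k = dim_k
        self.qkv = nn.Conv2d(dim, num_heads * dim_k + dim_k + dim_v, 1, bias=False)
        self.conv_lambda = nn.Conv3d(1, dim_k, (r, r, 1), padding=(r // 2, r // 2, 0))
        self.reset_parameters()

    def reset_parameters(self):
        trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
        trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)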

@@ -107,6 +107,7 @@ class WindowAttention(nn.Module):
         self.relative_position_bias_table = nn.Parameter(
             # 2 * Wh - 1 * 2 * Ww - 1, nH
             torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
+        trunc_normal_(self.relative_position_bias_table, std=.02)

         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.win_size)
@@ -120,13 +121,16 @@
         relative_coords[:, :, 0] *= 2 * self.win_size - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
         self.register_buffer("relative_position_index", relative_position_index)
-        trunc_normal_(self.relative_position_bias_table, std=.02)

         self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
         self.softmax = nn.Softmax(dim=-1)
         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

+    def reset_parameters(self):
+        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+
     def forward(self, x):
         B, C, H, W = x.shape
         x = x.permute(0, 2, 3, 1)
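
In WindowAttention the bias-table init moves next to the parameter definition, and reset_parameters() combines a fan-in std for the qkv Linear (weight.shape[1] is the input dim) with the Swin-style fixed std of 0.02 for the learned relative-position bias table; putting both in reset_parameters() also makes the init re-runnable. A toy version of that two-part reset (illustrative only, not timm's WindowAttention):

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_

class ToyWindowAttn(nn.Module):
    def __init__(self, dim=512, win_size=8, num_heads=8):
        super().__init__()
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_size - 1) * (2 * win_size - 1), num_heads))
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        # fan-in init for the projection, small fixed std for the bias table
        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
        trunc_normal_(self.relative_position_bias_table, std=.02)

attn = ToyWindowAttn()
# qkv std ≈ 512 ** -0.5 ≈ 0.044, bias table std ≈ 0.02
print(attn.qkv.weight.std().item(), attn.relative_position_bias_table.std().item())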
