@@ -107,6 +107,7 @@ class WindowAttention(nn.Module):
        self.relative_position_bias_table = nn.Parameter(
            # (2 * Wh - 1) * (2 * Ww - 1), nH
            torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))
        trunc_normal_(self.relative_position_bias_table, std=.02)

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.win_size)
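
For context, the file lines elided between these two hunks build the pair-wise relative position index: window coordinates are flattened, differenced pair-wise, shifted to start at zero, and encoded row-major. The sketch below reconstructs that computation from the surrounding lines and the standard Swin-style formulation; it is an illustration, not the verbatim file contents. win_size stands in for self.win_size, and a square window is assumed.

import torch

win_size = 7  # illustrative; the module uses self.win_size
coords_h = torch.arange(win_size)
coords_w = torch.arange(win_size)
# 2, Wh, Ww grid of (row, col) coordinates; indexing= requires torch >= 1.10
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'))
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
# pair-wise differences between every pair of window positions
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += win_size - 1  # shift row deltas into [0, 2*Wh-2]
relative_coords[:, :, 1] += win_size - 1  # shift col deltas into [0, 2*Ww-2]
relative_coords[:, :, 0] *= 2 * win_size - 1  # row-major encoding, as in the hunk below
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
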
@@ -120,13 +121,16 @@ class WindowAttention(nn.Module):
        relative_coords[:, :, 0] *= 2 * self.win_size - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        trunc_normal_(self.relative_position_bias_table, std=.02)

        self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.softmax = nn.Softmax(dim=-1)
        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

    def reset_parameters(self):
        trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
        trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1)
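
For reference, relative_position_index is used at attention time to gather a per-head bias from relative_position_bias_table; that part of forward() is outside this diff, so the following is a hedged sketch of the standard usage (continuing the sketch between the hunks above), not the file's actual code. num_heads is illustrative.

num_heads = 8  # illustrative
N = win_size * win_size  # tokens per window
relative_position_bias_table = torch.zeros((2 * win_size - 1) ** 2, num_heads)
# one learned bias per (query token, key token, head), with heads moved first
bias = relative_position_bias_table[relative_position_index.view(-1)]
bias = bias.view(N, N, num_heads).permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
# attn = softmax(q @ k.transpose(-2, -1) * scale + bias.unsqueeze(0))

On reset_parameters(): nn.Linear stores its weight as (out_features, in_features), so self.qkv.weight.shape[1] is the fan-in (dim), giving the qkv projection an init std of dim ** -0.5.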