From 0b718a82c7ab3f48ed01978905500e5475d920a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E8=8F=9C?= <32389939+869019048@users.noreply.github.com> Date: Sat, 5 Jun 2021 16:58:36 +0800 Subject: [PATCH] One can implement Container-PAM and obtain a +0.5 improvement on ImageNet top-1 accuracy (from: https://arxiv.org/pdf/2106.01401.pdf) --- timm/models/vision_transformer.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index ff74d836..f3b8614c 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -140,6 +140,11 @@ class Attention(nn.Module): self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + + #Uncomment this line for Container-PAM + #self.static_a =nn.Parameter(torch.Tensor(1, num_heads, 1 + seq_l , 1 + seq_l)) + #trunc_normal_(self.static_a) + self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) @@ -151,6 +156,10 @@ class Attention(nn.Module): attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) + + #Uncomment this line for Container-PAM + #attn = attn + self.static_a + attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) @@ -669,4 +678,4 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs): """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) - return model \ No newline at end of file + return model