diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index ff74d836..f3b8614c 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -140,6 +140,11 @@ class Attention(nn.Module): self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + + #Uncomment this line for Container-PAM + #self.static_a =nn.Parameter(torch.Tensor(1, num_heads, 1 + seq_l , 1 + seq_l)) + #trunc_normal_(self.static_a) + self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) @@ -151,6 +156,10 @@ class Attention(nn.Module): attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) + + #Uncomment this line for Container-PAM + #attn = attn + self.static_a + attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) @@ -669,4 +678,4 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs): """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) - return model \ No newline at end of file + return model