|
|
|
@ -140,6 +140,11 @@ class Attention(nn.Module):
|
|
|
|
|
self.scale = qk_scale or head_dim ** -0.5
|
|
|
|
|
|
|
|
|
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
|
|
|
|
|
|
|
|
|
#Uncomment this line for Container-PAM
|
|
|
|
|
#self.static_a =nn.Parameter(torch.Tensor(1, num_heads, 1 + seq_l , 1 + seq_l))
|
|
|
|
|
#trunc_normal_(self.static_a)
|
|
|
|
|
|
|
|
|
|
self.attn_drop = nn.Dropout(attn_drop)
|
|
|
|
|
self.proj = nn.Linear(dim, dim)
|
|
|
|
|
self.proj_drop = nn.Dropout(proj_drop)
|
|
|
|
@ -151,6 +156,10 @@ class Attention(nn.Module):
|
|
|
|
|
|
|
|
|
|
attn = (q @ k.transpose(-2, -1)) * self.scale
|
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
|
|
|
|
|
|
#Uncomment this line for Container-PAM
|
|
|
|
|
#attn = attn + self.static_a
|
|
|
|
|
|
|
|
|
|
attn = self.attn_drop(attn)
|
|
|
|
|
|
|
|
|
|
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
|
|
|
@ -669,4 +678,4 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
return model
|
|
|
|
|