@@ -166,7 +166,7 @@ class Attention(nn.Module):
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
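For reference, a minimal shape-tracing sketch of the attention math in the hunk above. It is not part of the patch: the sizes (B=2, N=197, C=192, num_heads=3) are illustrative, and qkv_proj and scale are stand-ins for the module's self.qkv and self.scale; the attention/projection dropouts and output projection of the real Attention module are omitted.

import torch

# Illustrative only; mirrors the reshape/permute/softmax steps of Attention.forward above.
B, N, C, num_heads = 2, 197, 192, 3
x = torch.randn(B, N, C)
qkv_proj = torch.nn.Linear(C, C * 3)      # stand-in for self.qkv
scale = (C // num_heads) ** -0.5          # stand-in for self.scale

qkv = qkv_proj(x).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]          # each: (B, num_heads, N, head_dim)

attn = (q @ k.transpose(-2, -1)) * scale  # (B, num_heads, N, N)
attn = attn.softmax(dim=-1)
print(attn.shape)                         # torch.Size([2, 3, 197, 197])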
@@ -663,7 +663,7 @@ def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model
@@ -674,7 +674,7 @@ def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model
@@ -685,7 +685,7 @@ def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model
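The three distilled DeiT factories above differ only in embed_dim and num_heads (the per-head dimension stays 64 in each case). A hypothetical usage sketch follows, assuming a timm version in which these names are registered and resolvable through timm.create_model; the 1000-class output shape is the library's default head size, not something stated in the patch.

import torch
import timm

# Illustrative only; builds one of the distilled variants defined above by its registered name.
model = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000]) with the default classification head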