@@ -172,7 +172,7 @@ class WindowAttention(nn.Module):
         """
         B_, N, C = x.shape
         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
 
         q = q * self.scale
         attn = (q @ k.transpose(-2, -1))
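
The hunk above replaces per-index slicing of the stacked qkv tensor with Tensor.unbind(0), which returns a plain Python tuple and therefore destructures cleanly under TorchScript (the retained comment refers to TorchScript refusing to unpack a tensor as a tuple). A minimal standalone sketch, not the timm module itself; the toy shapes and the qkv_proj stand-in layer are assumptions for illustration, showing that both forms produce identical tensors:

    import torch
    import torch.nn as nn

    # Toy dimensions: windows*batch, tokens per window, channels, heads (assumed values).
    B_, N, C, num_heads = 2, 49, 96, 3
    qkv_proj = nn.Linear(C, C * 3)  # stand-in for self.qkv in WindowAttention
    x = torch.randn(B_, N, C)

    # Same reshape/permute as in the diff: result has shape (3, B_, num_heads, N, C // num_heads).
    qkv = qkv_proj(x).reshape(B_, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)

    q0, k0, v0 = qkv[0], qkv[1], qkv[2]  # old form: index the leading dim three times
    q1, k1, v1 = qkv.unbind(0)           # new form: one call, returns a tuple
    assert torch.equal(q0, q1) and torch.equal(k0, k1) and torch.equal(v0, v1)
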
@@ -649,4 +649,4 @@ def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
     """
     model_kwargs = dict(
         patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
-    return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
+    return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
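
The hunk above falls in the swin_large_patch4_window7_224_in22k factory at the end of the file, which builds Swin-L (patch size 4, window 7, ImageNet-22k pretraining) through _create_swin_transformer. A usage sketch, assuming a timm version that still exposes this model name:

    import timm

    # Instantiate the registered model; pretrained=True would download the 22k weights.
    model = timm.create_model('swin_large_patch4_window7_224_in22k', pretrained=False)
    print(sum(p.numel() for p in model.parameters()))  # rough sanity check: Swin-L parameter count
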