@ -164,7 +164,7 @@ class WindowAttention(nn.Module):
torch.meshgrid(
[
relative_coords_h,
relative_coords_w
relative_coords_w,
],
indexing='ij',
)