Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 8bb6b3687b
commit 718508309f

@ -309,7 +309,6 @@ class SpatialBlock(nn.Module):
H, W = size H, W = size
B, L, C = x.shape B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = self.cpe[0](x, size) shortcut = self.cpe[0](x, size)
x = self.norm1(shortcut) x = self.norm1(shortcut)
@ -334,7 +333,7 @@ class SpatialBlock(nn.Module):
C) C)
x = window_reverse(attn_windows, self.window_size, Hp, Wp) x = window_reverse(attn_windows, self.window_size, Hp, Wp)
if pad_r > 0 or pad_b > 0: #if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous() x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C) x = x.view(B, H * W, C)

Loading…
Cancel
Save