Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 7dd25eb402
commit 1327a6af4c

@ -508,11 +508,11 @@ class SpatialBlock(nn.Module):
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
x = self.cpe2(x)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
if self.ffn:
x = x + self.drop_path(self.mlp(self.norm2(x)))
x = x.transpose(1, 2).view(B, C, H, W)
x = x.transpose(1, 2).view(B, C, H, W)
return x

Loading…
Cancel
Save