fix: torchscript

pull/382/head
iyaja 5 years ago
parent 225d45afac
commit 499790e117

@ -53,9 +53,9 @@ class TripletAttention(nn.Module):
super(TripletAttention, self).__init__()
self.cw = AttentionGate()
self.hc = AttentionGate()
self.no_spatial=no_spatial
if not no_spatial:
self.hw = AttentionGate()
self.hw = nn.Identity() if no_spatial else AttentionGate()
self.no_spatial = no_spatial
def forward(self, x):
x_perm1 = x.permute(0,2,1,3).contiguous()
x_out1 = self.cw(x_perm1)
@ -63,9 +63,7 @@ class TripletAttention(nn.Module):
x_perm2 = x.permute(0,3,2,1).contiguous()
x_out2 = self.hc(x_perm2)
x_out21 = x_out2.permute(0,3,2,1).contiguous()
if not self.no_spatial:
x_out = self.hw(x)
x_out = (1/3) * (x_out + x_out11 + x_out21)
else:
x_out = 0.5 * (x_out11 + x_out21)
x_out = self.hw(x)
x_out = (1/2) * (x_out11 + x_out21) if self.no_spatial else (1/3) * (x_out + x_out11 + x_out21)
return x_out
Loading…
Cancel
Save