@@ -53,9 +53,9 @@ class TripletAttention(nn.Module):
         super(TripletAttention, self).__init__()
         self.cw = AttentionGate()
         self.hc = AttentionGate()
-        self.no_spatial=no_spatial
-        if not no_spatial:
-            self.hw = AttentionGate()
+        self.hw = nn.Identity() if no_spatial else AttentionGate()
+        self.no_spatial = no_spatial
 
     def forward(self, x):
         x_perm1 = x.permute(0,2,1,3).contiguous()
         x_out1 = self.cw(x_perm1)
|
@@ -63,9 +63,7 @@ class TripletAttention(nn.Module):
         x_perm2 = x.permute(0,3,2,1).contiguous()
         x_out2 = self.hc(x_perm2)
         x_out21 = x_out2.permute(0,3,2,1).contiguous()
-        if not self.no_spatial:
-            x_out = self.hw(x)
-            x_out = (1/3) * (x_out + x_out11 + x_out21)
-        else:
-            x_out = 0.5 * (x_out11 + x_out21)
+        x_out = self.hw(x)
+        x_out = (1/2) * (x_out11 + x_out21) if self.no_spatial else (1/3) * (x_out + x_out11 + x_out21)
         return x_out