From 499790e117b2c8c1b57780b73d16c28b84db509e Mon Sep 17 00:00:00 2001 From: iyaja Date: Sun, 31 Jan 2021 14:30:56 +0530 Subject: [PATCH] fix: torchscript --- timm/models/layers/triplet.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/timm/models/layers/triplet.py b/timm/models/layers/triplet.py index efbc8450..fcc4781f 100644 --- a/timm/models/layers/triplet.py +++ b/timm/models/layers/triplet.py @@ -53,9 +53,9 @@ class TripletAttention(nn.Module): super(TripletAttention, self).__init__() self.cw = AttentionGate() self.hc = AttentionGate() - self.no_spatial=no_spatial - if not no_spatial: - self.hw = AttentionGate() + self.hw = nn.Identity() if no_spatial else AttentionGate() + self.no_spatial = no_spatial + def forward(self, x): x_perm1 = x.permute(0,2,1,3).contiguous() x_out1 = self.cw(x_perm1) @@ -63,9 +63,7 @@ class TripletAttention(nn.Module): x_perm2 = x.permute(0,3,2,1).contiguous() x_out2 = self.hc(x_perm2) x_out21 = x_out2.permute(0,3,2,1).contiguous() - if not self.no_spatial: - x_out = self.hw(x) - x_out = (1/3) * (x_out + x_out11 + x_out21) - else: - x_out = 0.5 * (x_out11 + x_out21) + x_out = self.hw(x) + + x_out = (1/2) * (x_out11 + x_out21) if self.no_spatial else (1/3) * (x_out + x_out11 + x_out21) return x_out \ No newline at end of file