@@ -48,9 +48,9 @@ class AttentionGate(nn.Module):
         scale = torch.sigmoid_(x_out)
         return x * scale

-class TripletAttention(nn.Module):
+class TripletModule(nn.Module):
     def __init__(self, no_spatial=False):
-        super(TripletAttention, self).__init__()
+        super(TripletModule, self).__init__()
         self.cw = AttentionGate()
         self.hc = AttentionGate()
         self.no_spatial=no_spatial