|
|
@ -48,9 +48,9 @@ class AttentionGate(nn.Module):
|
|
|
|
scale = torch.sigmoid_(x_out)
|
|
|
|
scale = torch.sigmoid_(x_out)
|
|
|
|
return x * scale
|
|
|
|
return x * scale
|
|
|
|
|
|
|
|
|
|
|
|
class TripletAttention(nn.Module):
|
|
|
|
class TripletModule(nn.Module):
|
|
|
|
def __init__(self, no_spatial=False):
|
|
|
|
def __init__(self, no_spatial=False):
|
|
|
|
super(TripletAttention, self).__init__()
|
|
|
|
super(TripletModule, self).__init__()
|
|
|
|
self.cw = AttentionGate()
|
|
|
|
self.cw = AttentionGate()
|
|
|
|
self.hc = AttentionGate()
|
|
|
|
self.hc = AttentionGate()
|
|
|
|
self.no_spatial=no_spatial
|
|
|
|
self.no_spatial=no_spatial
|
|
|
|