diff --git a/timm/models/layers/triplet.py b/timm/models/layers/triplet.py index efbc8450..5e8f4d78 100644 --- a/timm/models/layers/triplet.py +++ b/timm/models/layers/triplet.py @@ -48,9 +48,9 @@ class AttentionGate(nn.Module): scale = torch.sigmoid_(x_out) return x * scale -class TripletAttention(nn.Module): +class TripletModule(nn.Module): def __init__(self, no_spatial=False): - super(TripletAttention, self).__init__() + super(TripletModule, self).__init__() self.cw = AttentionGate() self.hc = AttentionGate() self.no_spatial=no_spatial