You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
71 lines
2.7 KiB
71 lines
2.7 KiB
""" Triplet Attention Module

Implementation of triplet attention module from https://arxiv.org/abs/2010.03045

(slightly) Modified from official implementation: https://github.com/LandskapeAI/triplet-attention

Original license:

MIT License

Copyright (c) 2020 LandskapeAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
|
|
|
|
import torch
|
|
from torch import nn as nn
|
|
from .conv_bn_act import ConvBnAct
|
|
|
|
class ZPool(nn.Module):
    """Collapse dim 1 of the input into two slices: its maximum and its mean.

    The result has size 2 along dim 1; index 0 holds the max over the
    original dim 1 and index 1 holds the mean over it. All other dims are
    left unchanged.
    """

    def forward(self, x):
        """Return the dim-1 max and dim-1 mean of ``x`` stacked along dim 1."""
        # keepdim=True leaves a singleton dim 1 in place, so no unsqueeze is needed.
        peak = x.max(dim=1, keepdim=True)[0]
        average = x.mean(dim=1, keepdim=True)
        return torch.cat([peak, average], dim=1)
|
|
|
|
class AttentionGate(nn.Module):
    """Scale an input by a sigmoid gate derived from its dim-1 max/mean pooling.

    The input is pooled by ``ZPool``, passed through ``ConvBnAct`` and squashed
    with a sigmoid; the resulting map multiplies the original input.

    Args:
        kernel_size: Kernel size handed to ``ConvBnAct``. The padding is set to
            ``(kernel_size - 1) // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        pad = (kernel_size - 1) // 2
        self.zpool = ZPool()
        # NOTE(review): ConvBnAct comes from .conv_bn_act; presumably a
        # conv + batch-norm block mapping 2 channels to 1, with the activation
        # disabled via apply_act=False -- confirm against that module.
        self.conv = ConvBnAct(
            2,
            1,
            kernel_size=kernel_size,
            stride=1,
            padding=pad,
            apply_act=False,
        )

    def forward(self, x):
        """Return ``x`` multiplied by the sigmoid gate computed from ``x``."""
        pooled = self.zpool(x)
        gate = self.conv(pooled)
        # In-place sigmoid: overwrites the conv output rather than allocating
        # a new tensor.
        gate.sigmoid_()
        return x * gate
|
|
|
|
class TripletAttention(nn.Module):
    """Average of attention gates applied to permuted views of a 4-D input.

    Two gates always run: ``cw`` on the input with dims 1 and 2 swapped, and
    ``hc`` on the input with dims 1 and 3 swapped. Each result is permuted
    back to the input layout. Unless ``no_spatial`` is set, a third gate
    ``hw`` runs on the unpermuted input. The branch outputs are averaged.

    Args:
        no_spatial: When True, the ``hw`` gate is not created and the output is
            the mean of the two permuted branches only.
    """

    def __init__(self, no_spatial=False):
        super().__init__()
        self.cw = AttentionGate()
        self.hc = AttentionGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.hw = AttentionGate()

    def forward(self, x):
        """Apply the gates to ``x`` and return the averaged result.

        ``x`` must be 4-dimensional (the permutes use four indices);
        presumably laid out as (batch, channels, height, width) -- TODO confirm.
        """
        # Branch 1: swap dims 1 and 2, gate, then swap back.
        swapped_12 = x.permute(0, 2, 1, 3).contiguous()
        branch_12 = self.cw(swapped_12).permute(0, 2, 1, 3).contiguous()

        # Branch 2: swap dims 1 and 3, gate, then swap back.
        swapped_13 = x.permute(0, 3, 2, 1).contiguous()
        branch_13 = self.hc(swapped_13).permute(0, 3, 2, 1).contiguous()

        if self.no_spatial:
            return 0.5 * (branch_12 + branch_13)

        # Branch 3: gate the input in its original layout.
        plain = self.hw(x)
        return (1 / 3) * (plain + branch_12 + branch_13)