From 06b86a3f7a2d6960bac9cde8c13068f51edaa30b Mon Sep 17 00:00:00 2001
From: iyaja
Date: Mon, 25 Jan 2021 00:32:36 +0530
Subject: [PATCH] feat: add triplet attention layer

---
 timm/models/layers/create_attn.py |   3 +
 timm/models/layers/triplet.py     |  75 +++++++++++++++++++
 timm/models/resnet.py             | 122 ++++++++++++++++++++++++++++++
 3 files changed, 200 insertions(+)
 create mode 100644 timm/models/layers/triplet.py

diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py
index 59ecd858..a69bd5e5 100644
--- a/timm/models/layers/create_attn.py
+++ b/timm/models/layers/create_attn.py
@@ -6,6 +6,7 @@ import torch
 from .se import SEModule, EffectiveSEModule
 from .eca import EcaModule, CecaModule
 from .cbam import CbamModule, LightCbamModule
+from .triplet import TripletAttention
 
 
 def create_attn(attn_type, channels, **kwargs):
@@ -25,6 +26,8 @@ def create_attn(attn_type, channels, **kwargs):
                 module_cls = CbamModule
             elif attn_type == 'lcbam':
                 module_cls = LightCbamModule
+            elif attn_type == 'triplet':
+                module_cls = TripletAttention
             else:
                 assert False, "Invalid attn module (%s)" % attn_type
         elif isinstance(attn_type, bool):
diff --git a/timm/models/layers/triplet.py b/timm/models/layers/triplet.py
new file mode 100644
index 00000000..efbc8450
--- /dev/null
+++ b/timm/models/layers/triplet.py
@@ -0,0 +1,75 @@
+""" Triplet Attention Module
+
+Implementation of the triplet attention module from https://arxiv.org/abs/2010.03045
+
+(Slightly) modified from the official implementation: https://github.com/LandskapeAI/triplet-attention
+
+Original license:
+
+MIT License
+
+Copyright (c) 2020 LandskapeAI
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+""" + +import torch +from torch import nn as nn +from .conv_bn_act import ConvBnAct + +class ZPool(nn.Module): + def forward(self, x): + return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) + +class AttentionGate(nn.Module): + def __init__(self, kernel_size=7): + super(AttentionGate, self).__init__() + self.zpool = ZPool() + self.conv = ConvBnAct(2, 1, kernel_size=kernel_size, stride=1, padding=(kernel_size-1) // 2, apply_act=False) + + def forward(self, x): + x_out = self.conv(self.zpool(x)) + scale = torch.sigmoid_(x_out) + return x * scale + +class TripletAttention(nn.Module): + def __init__(self, no_spatial=False): + super(TripletAttention, self).__init__() + self.cw = AttentionGate() + self.hc = AttentionGate() + self.no_spatial=no_spatial + if not no_spatial: + self.hw = AttentionGate() + def forward(self, x): + x_perm1 = x.permute(0,2,1,3).contiguous() + x_out1 = self.cw(x_perm1) + x_out11 = x_out1.permute(0,2,1,3).contiguous() + x_perm2 = x.permute(0,3,2,1).contiguous() + x_out2 = self.hc(x_perm2) + x_out21 = x_out2.permute(0,3,2,1).contiguous() + if not self.no_spatial: + x_out = self.hw(x) + x_out = (1/3) * (x_out + x_out11 + x_out21) + else: + x_out = 0.5 * (x_out11 + x_out21) + return x_out \ No newline at end of file diff --git a/timm/models/resnet.py b/timm/models/resnet.py index be0652bf..bd17df7a 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -1280,3 +1280,124 @@ def senet154(pretrained=False, **kwargs): block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs) return _create_resnet('senet154', pretrained, **model_args) + +@register_model +def triplet_resnet18(pretrained=False, **kwargs): + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='triplet'), **kwargs) + return _create_resnet('triplet_resnet18', pretrained, **model_args) + + +@register_model +def triplet_resnet34(pretrained=False, **kwargs): + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='triplet'), **kwargs) + return _create_resnet('triplet_resnet34', pretrained, **model_args) + + +@register_model +def triplet_resnet50(pretrained=False, **kwargs): + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='triplet'), **kwargs) + return _create_resnet('triplet_resnet50', pretrained, **model_args) + + +@register_model +def triplet_resnet50tn(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, + block_args=dict(attn_layer='triplet'), **kwargs) + return _create_resnet('triplet_resnet50tn', pretrained, **model_args) + + +@register_model +def triplet_resnet101(pretrained=False, **kwargs): + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='triplet'), **kwargs) + return _create_resnet('triplet_resnet101', pretrained, **model_args) + + +@register_model +def triplet_resnet152(pretrained=False, **kwargs): + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='triplet'), **kwargs) + return _create_resnet('triplet_resnet152', pretrained, **model_args) + + +@register_model +def triplet_resnet152d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='triplet'), **kwargs) + return 
+
+
+@register_model
+def triplet_resnet152d_320(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
+        block_args=dict(attn_layer='triplet'), **kwargs)
+    return _create_resnet('triplet_resnet152d_320', pretrained, **model_args)
+
+
+@register_model
+def triplet_resnext26_32x4d(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4,
+        block_args=dict(attn_layer='triplet'), **kwargs)
+    return _create_resnet('triplet_resnext26_32x4d', pretrained, **model_args)
+
+
+@register_model
+def triplet_resnext26d_32x4d(pretrained=False, **kwargs):
+    """Constructs a ResNeXt-26-D (with Triplet Attention) model.
+    This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
+    combination of deep stem and avg_pool in downsample.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep', avg_down=True, block_args=dict(attn_layer='triplet'), **kwargs)
+    return _create_resnet('triplet_resnext26d_32x4d', pretrained, **model_args)
+
+
+@register_model
+def triplet_resnext26t_32x4d(pretrained=False, **kwargs):
+    """Constructs a ResNeXt-26-T (with Triplet Attention) model.
+    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
+    in the deep stem.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='triplet'), **kwargs)
+    return _create_resnet('triplet_resnext26t_32x4d', pretrained, **model_args)
+
+
+@register_model
+def triplet_resnext26tn_32x4d(pretrained=False, **kwargs):
+    """Constructs a ResNeXt-26-TN (with Triplet Attention) model.
+    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
+    in the deep stem. The middle stem conv is narrower than in the 'T' variant.
+    """
+    model_args = dict(
+        block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+        stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='triplet'), **kwargs)
+    return _create_resnet('triplet_resnext26tn_32x4d', pretrained, **model_args)
+
+
+@register_model
+def triplet_resnext50_32x4d(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
+        block_args=dict(attn_layer='triplet'), **kwargs)
+    return _create_resnet('triplet_resnext50_32x4d', pretrained, **model_args)
+
+
+@register_model
+def triplet_resnext101_32x4d(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
+        block_args=dict(attn_layer='triplet'), **kwargs)
+    return _create_resnet('triplet_resnext101_32x4d', pretrained, **model_args)
+
+
+@register_model
+def triplet_resnext101_32x8d(pretrained=False, **kwargs):
+    model_args = dict(
+        block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
+        block_args=dict(attn_layer='triplet'), **kwargs)
+    return _create_resnet('triplet_resnext101_32x8d', pretrained, **model_args)
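
A minimal sanity-check sketch, assuming the patch above is applied. It exercises the new module both directly and through the 'triplet' key wired into create_attn. Triplet attention has no channel-dependent parameters, so the channels argument passed by the factory is accepted and ignored, and the output shape always matches the input.

    import torch
    from timm.models.layers.create_attn import create_attn
    from timm.models.layers.triplet import TripletAttention

    x = torch.randn(2, 64, 32, 32)  # (B, C, H, W)

    # Direct use: the layer is channel-agnostic, so no arguments are required.
    attn = TripletAttention()
    assert attn(x).shape == x.shape  # attention rescales activations, never reshapes them

    # Through the factory, the same path the ResNet blocks take via block_args=dict(attn_layer='triplet').
    attn = create_attn('triplet', 64)  # the channels arg (64) is accepted and ignored
    assert attn(x).shape == x.shape

    # Only the three 7x7 ConvBnAct gates carry parameters.
    print(sum(p.numel() for p in attn.parameters()))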