refactor: use fewer triplet models

pull/382/head
Ajay Uppili Arasanipalai 5 years ago
parent d17adadabb
commit 368834ab3c

@@ -6,7 +6,7 @@ import torch
from .se import SEModule, EffectiveSEModule
from .eca import EcaModule, CecaModule
from .cbam import CbamModule, LightCbamModule
from .triplet import TripletModule
from .triplet import TripletAttention
def create_attn(attn_type, channels, **kwargs):
@@ -27,7 +27,7 @@ def create_attn(attn_type, channels, **kwargs):
elif attn_type == 'lcbam':
module_cls = LightCbamModule
elif attn_type == 'triplet':
module_cls = TripletModule
module_cls = TripletAttention
else:
assert False, "Invalid attn module (%s)" % attn_type
elif isinstance(attn_type, bool):
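Because ResNet blocks request their attention layer by string (e.g. block_args=dict(attn_layer='triplet')), the rename is invisible to callers; only the class the factory resolves to changes. A minimal, self-contained sketch of that dispatch pattern (stand-in class and dict lookup are illustrative, not timm's exact code):

```python
import torch.nn as nn

class TripletAttention(nn.Module):
    """Stand-in for the renamed class; the real implementation lives in .triplet."""
    def forward(self, x):
        return x

_ATTN_CLASSES = {
    'triplet': TripletAttention,  # resolved to TripletModule before this commit
}

def create_attn_sketch(attn_type, channels, **kwargs):
    # The if/elif chain above expressed as a lookup; purely illustrative.
    if isinstance(attn_type, str):
        module_cls = _ATTN_CLASSES.get(attn_type.lower())
        assert module_cls is not None, "Invalid attn module (%s)" % attn_type
        return module_cls(**kwargs)
    return None

attn = create_attn_sketch('triplet', channels=64)  # now a TripletAttention instance
```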

@@ -48,9 +48,9 @@ class AttentionGate(nn.Module):
scale = torch.sigmoid_(x_out)
return x * scale
class TripletModule(nn.Module):
class TripletAttention(nn.Module):
def __init__(self, no_spatial=False):
super(TripletModule, self).__init__()
super(TripletAttention, self).__init__()
self.cw = AttentionGate()
self.hc = AttentionGate()
self.no_spatial=no_spatial
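The hunk only shows the constructor; for context, the forward pass described in the Triplet Attention paper rotates the tensor so each AttentionGate attends over a different pair of dimensions, then averages the branches. A sketch of that structure, assuming a gate like the one shown above (paper-style reference code, not necessarily this commit's exact implementation):

```python
import torch
import torch.nn as nn

class AttentionGateSketch(nn.Module):
    """Minimal gate in the spirit of AttentionGate above: pool across dim 1,
    reduce to one map with a 7x7 conv, sigmoid, and rescale the input."""
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        # Stack max- and mean-pooled maps along the (current) leading feature dim.
        pooled = torch.cat([x.max(dim=1, keepdim=True)[0], x.mean(dim=1, keepdim=True)], dim=1)
        return x * torch.sigmoid(self.conv(pooled))

class TripletAttentionSketch(nn.Module):
    """Illustrative three-branch forward for the class renamed above."""
    def __init__(self, no_spatial=False):
        super().__init__()
        self.cw = AttentionGateSketch()      # gate over the (C, W) plane
        self.hc = AttentionGateSketch()      # gate over the (H, C) plane
        self.no_spatial = no_spatial
        if not no_spatial:
            self.hw = AttentionGateSketch()  # ordinary spatial (H, W) gate

    def forward(self, x):  # x: (N, C, H, W)
        # Swap C and H so the gate attends over the (C, W) plane, then swap back.
        x_cw = self.cw(x.permute(0, 2, 1, 3).contiguous()).permute(0, 2, 1, 3)
        # Swap C and W so the gate attends over the (H, C) plane, then swap back.
        x_hc = self.hc(x.permute(0, 3, 2, 1).contiguous()).permute(0, 3, 2, 1)
        if self.no_spatial:
            return 0.5 * (x_cw + x_hc)
        return (x_cw + x_hc + self.hw(x)) / 3.0
```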

@@ -227,57 +227,18 @@ default_cfgs = {
interpolation='bicubic'),
# Triplet Attention ResNets
'triplet_resnet18': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnet34': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnet50': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnet50tn': _cfg(
'triplet_resnet18d': _cfg(
url='',
interpolation='bicubic',
first_conv='conv1.0'),
'triplet_resnet101': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnet152': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnet152d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'triplet_resnet152d_320': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
# Triplet Attention ResNeXts
'triplet_resnext26_32x4d': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnext26d_32x4d': _cfg(
url='',
interpolation='bicubic',
first_conv='conv1.0'),
'triplet_resnext26t_32x4d': _cfg(
'triplet_resnet50d': _cfg(
url='',
interpolation='bicubic',
first_conv='conv1.0'),
'triplet_resnext26tn_32x4d': _cfg(
'triplet_resnet101d': _cfg(
url='',
interpolation='bicubic',
first_conv='conv1.0'),
'triplet_resnext50_32x4d': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnext101_32x4d': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnext101_32x8d': _cfg(
url='',
interpolation='bicubic'),
# ResNets with anti-aliasing blur pool
'resnetblur18': _cfg(
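The configs kept after the consolidation are the deep-stem 'd' variants, hence first_conv='conv1.0': with a deep stem the first convolution sits inside a sequential stem module rather than being a single conv1, and checkpoint loading needs that name. The _cfg helper itself just fills in ImageNet-style defaults that per-model kwargs override; roughly (a sketch of the pattern, not necessarily timm's exact field list):

```python
def _cfg(url='', **kwargs):
    # Defaults first, then per-model overrides such as interpolation='bicubic',
    # first_conv='conv1.0', or a larger input_size/crop_pct.
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
        'first_conv': 'conv1', 'classifier': 'fc',
        **kwargs,
    }
```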
@@ -1335,122 +1296,22 @@ def senet154(pretrained=False, **kwargs):
return _create_resnet('senet154', pretrained, **model_args)
@register_model
def triplet_resnet18(pretrained=False, **kwargs):
def triplet_resnet18d(pretrained=False, **kwargs):
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet18', pretrained, **model_args)
@register_model
def triplet_resnet34(pretrained=False, **kwargs):
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet34', pretrained, **model_args)
@register_model
def triplet_resnet50(pretrained=False, **kwargs):
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet50', pretrained, **model_args)
@register_model
def triplet_resnet50tn(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet50tn', pretrained, **model_args)
@register_model
def triplet_resnet101(pretrained=False, **kwargs):
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet101', pretrained, **model_args)
@register_model
def triplet_resnet152(pretrained=False, **kwargs):
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet152', pretrained, **model_args)
@register_model
def triplet_resnet152d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet152d', pretrained, **model_args)
@register_model
def triplet_resnet152d_320(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet152d_320', pretrained, **model_args)
@register_model
def triplet_resnext26_32x4d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4,
block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext26_32x4d', pretrained, **model_args)
@register_model
def triplet_resnext26d_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt-26-D (with Triplet Attention) model.`
This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
combination of deep stem and avg_pool in downsample.
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep', avg_down=True, block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext26d_32x4d', pretrained, **model_args)
@register_model
def triplet_resnext26t_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNet-26-T (with Triplet Attention) model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
in the deep stem.
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext26t_32x4d', pretrained, **model_args)
@register_model
def triplet_resnext26tn_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt-26-TN (with Triplet Attention) model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext26tn_32x4d', pretrained, **model_args)
@register_model
def triplet_resnext50_32x4d(pretrained=False, **kwargs):
def triplet_resnet50d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext50_32x4d', pretrained, **model_args)
@register_model
def triplet_resnext101_32x4d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext101_32x4d', pretrained, **model_args)
return _create_resnet('triplet_resnet50', pretrained, **model_args)
@register_model
def triplet_resnext101_32x8d(pretrained=False, **kwargs):
def triplet_resnet101d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext101_32x8d', pretrained, **model_args)
return _create_resnet('triplet_resnet101', pretrained, **model_args)
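After this change, the remaining registered entry points are triplet_resnet18d, triplet_resnet50d, and triplet_resnet101d. A hedged usage sketch, assuming this PR branch of timm is installed (the config urls are empty, so no pretrained weights are available and pretrained must stay False):

```python
import torch
import timm

# Name registered by @register_model in this commit; weights are randomly
# initialized because the default_cfgs urls above are empty.
model = timm.create_model('triplet_resnet50d', pretrained=False)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```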