refactor: use fewer triplet models

pull/382/head
Ajay Uppili Arasanipalai 5 years ago
parent d17adadabb
commit 368834ab3c

@ -6,7 +6,7 @@ import torch
from .se import SEModule, EffectiveSEModule from .se import SEModule, EffectiveSEModule
from .eca import EcaModule, CecaModule from .eca import EcaModule, CecaModule
from .cbam import CbamModule, LightCbamModule from .cbam import CbamModule, LightCbamModule
from .triplet import TripletModule from .triplet import TripletAttention
def create_attn(attn_type, channels, **kwargs): def create_attn(attn_type, channels, **kwargs):
@ -27,7 +27,7 @@ def create_attn(attn_type, channels, **kwargs):
elif attn_type == 'lcbam': elif attn_type == 'lcbam':
module_cls = LightCbamModule module_cls = LightCbamModule
elif attn_type == 'triplet': elif attn_type == 'triplet':
module_cls = TripletModule module_cls = TripletAttention
else: else:
assert False, "Invalid attn module (%s)" % attn_type assert False, "Invalid attn module (%s)" % attn_type
elif isinstance(attn_type, bool): elif isinstance(attn_type, bool):

@ -48,9 +48,9 @@ class AttentionGate(nn.Module):
scale = torch.sigmoid_(x_out) scale = torch.sigmoid_(x_out)
return x * scale return x * scale
class TripletModule(nn.Module): class TripletAttention(nn.Module):
def __init__(self, no_spatial=False): def __init__(self, no_spatial=False):
super(TripletModule, self).__init__() super(TripletAttention, self).__init__()
self.cw = AttentionGate() self.cw = AttentionGate()
self.hc = AttentionGate() self.hc = AttentionGate()
self.no_spatial=no_spatial self.no_spatial=no_spatial

@ -227,57 +227,18 @@ default_cfgs = {
interpolation='bicubic'), interpolation='bicubic'),
# Triplet Attention ResNets # Triplet Attention ResNets
'triplet_resnet18': _cfg( 'triplet_resnet18d': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnet34': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnet50': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnet50tn': _cfg(
url='', url='',
interpolation='bicubic', interpolation='bicubic',
first_conv='conv1.0'), first_conv='conv1.0'),
'triplet_resnet101': _cfg( 'triplet_resnet50d': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnet152': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnet152d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'triplet_resnet152d_320': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
# Triplet Attention ResNeXts
'triplet_resnext26_32x4d': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnext26d_32x4d': _cfg(
url='',
interpolation='bicubic',
first_conv='conv1.0'),
'triplet_resnext26t_32x4d': _cfg(
url='', url='',
interpolation='bicubic', interpolation='bicubic',
first_conv='conv1.0'), first_conv='conv1.0'),
'triplet_resnext26tn_32x4d': _cfg( 'triplet_resnet101d': _cfg(
url='', url='',
interpolation='bicubic', interpolation='bicubic',
first_conv='conv1.0'), first_conv='conv1.0'),
'triplet_resnext50_32x4d': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnext101_32x4d': _cfg(
url='',
interpolation='bicubic'),
'triplet_resnext101_32x8d': _cfg(
url='',
interpolation='bicubic'),
# ResNets with anti-aliasing blur pool # ResNets with anti-aliasing blur pool
'resnetblur18': _cfg( 'resnetblur18': _cfg(
@ -1335,122 +1296,22 @@ def senet154(pretrained=False, **kwargs):
return _create_resnet('senet154', pretrained, **model_args) return _create_resnet('senet154', pretrained, **model_args)
@register_model @register_model
def triplet_resnet18(pretrained=False, **kwargs): def triplet_resnet18d(pretrained=False, **kwargs):
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='triplet'), **kwargs) model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet18', pretrained, **model_args) return _create_resnet('triplet_resnet18', pretrained, **model_args)
@register_model @register_model
def triplet_resnet34(pretrained=False, **kwargs): def triplet_resnet50d(pretrained=False, **kwargs):
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet34', pretrained, **model_args)
@register_model
def triplet_resnet50(pretrained=False, **kwargs):
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet50', pretrained, **model_args)
@register_model
def triplet_resnet50tn(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered_narrow', avg_down=True,
block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet50tn', pretrained, **model_args)
@register_model
def triplet_resnet101(pretrained=False, **kwargs):
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet101', pretrained, **model_args)
@register_model
def triplet_resnet152(pretrained=False, **kwargs):
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet152', pretrained, **model_args)
@register_model
def triplet_resnet152d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet152d', pretrained, **model_args)
@register_model
def triplet_resnet152d_320(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnet152d_320', pretrained, **model_args)
@register_model
def triplet_resnext26_32x4d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4,
block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext26_32x4d', pretrained, **model_args)
@register_model
def triplet_resnext26d_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt-26-D (with Triplet Attention) model.
This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
combination of deep stem and avg_pool in downsample.
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep', avg_down=True, block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext26d_32x4d', pretrained, **model_args)
@register_model
def triplet_resnext26t_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt-26-T (with Triplet Attention) model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 48, 64 channels
in the deep stem.
"""
model_args = dict( model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext26t_32x4d', pretrained, **model_args)
@register_model
def triplet_resnext26tn_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt-26-TN (with Triplet Attention) model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered_narrow', avg_down=True, block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext26tn_32x4d', pretrained, **model_args)
@register_model
def triplet_resnext50_32x4d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext50_32x4d', pretrained, **model_args)
@register_model
def triplet_resnext101_32x4d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
block_args=dict(attn_layer='triplet'), **kwargs) block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext101_32x4d', pretrained, **model_args) return _create_resnet('triplet_resnet50', pretrained, **model_args)
@register_model @register_model
def triplet_resnext101_32x8d(pretrained=False, **kwargs): def triplet_resnet101d(pretrained=False, **kwargs):
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='triplet'), **kwargs) block_args=dict(attn_layer='triplet'), **kwargs)
return _create_resnet('triplet_resnext101_32x8d', pretrained, **model_args) return _create_resnet('triplet_resnet101', pretrained, **model_args)
Loading…
Cancel
Save