Compare commits

...

4 Commits

@@ -94,6 +94,16 @@ default_cfgs = {
url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
'efficientnet_l2': _cfg(
url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
'efficientnet_eca_b0': _cfg(
url=''),
'efficientnet_eca_b1': _cfg(
url='',
input_size=(3, 240, 240), pool_size=(8, 8)),
'efficientnet_eca_b2': _cfg(
url='',
input_size=(3, 260, 260), pool_size=(9, 9)),
'xefficientnet_b0': _cfg(
url=''),
'efficientnet_es': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'),
'efficientnet_em': _cfg(
@@ -234,7 +244,7 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'),
}
-_DEBUG = False
+_DEBUG = True
class EfficientNet(nn.Module):
@@ -254,7 +264,7 @@ class EfficientNet(nn.Module):
def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
-se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
+attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
super(EfficientNet, self).__init__()
norm_kwargs = norm_kwargs or {}
@@ -272,8 +282,8 @@ class EfficientNet(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
-channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
-norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
+channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer,
+attn_layer, attn_kwargs, norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features
self._in_chs = builder.in_chs
@@ -334,7 +344,7 @@ class EfficientNetFeatures(nn.Module):
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
-se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
+attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(EfficientNetFeatures, self).__init__()
norm_kwargs = norm_kwargs or {}
@@ -354,8 +364,8 @@ class EfficientNetFeatures(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
-channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs,
-norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
+channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, attn_layer,
+attn_kwargs, norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features  # builder provides info about feature channels for each block
self._in_chs = builder.in_chs
@@ -627,13 +637,61 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
"""
arch_def = [
-['ds_r1_k3_s1_e1_c16_se0.25'],
-['ir_r2_k3_s2_e6_c24_se0.25'],
-['ir_r2_k5_s2_e6_c40_se0.25'],
-['ir_r3_k3_s2_e6_c80_se0.25'],
-['ir_r3_k5_s1_e6_c112_se0.25'],
-['ir_r4_k5_s2_e6_c192_se0.25'],
-['ir_r1_k3_s1_e6_c320_se0.25'],
+['ds_r1_k3_s1_e1_c16'],
+['ir_r2_k3_s2_e6_c24'],
+['ir_r2_k5_s2_e6_c40'],
+['ir_r3_k3_s2_e6_c80'],
+['ir_r3_k5_s1_e6_c112'],
+['ir_r4_k5_s2_e6_c192'],
+['ir_r1_k3_s1_e6_c320'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
num_features=round_channels(1280, channel_multiplier, 8, None),
stem_size=32,
channel_multiplier=channel_multiplier,
act_layer=Swish,
attn_layer='sev2',
attn_kwargs=dict(se_ratio=0.25),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_xefficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
"""Creates an EfficientNet model.
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Paper: https://arxiv.org/abs/1905.11946
EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
Args:
channel_multiplier: multiplier to number of channels per layer
depth_multiplier: multiplier to number of repeats per stage
"""
arch_def = [
['ds_r1_k3_s1_e1_c16'],
['ir_r2_k3_s2_e6_c24'],
['ir_r1_k5_s2_e6_c40', 'xir_r1_k5_s1_e6_c40'],
['ir_r3_k3_s2_e6_c80'],
['xir_r3_k5_s1_e6_c112'],
['ir_r1_k5_s2_e6_c192', 'xir_r3_k5_s1_e6_c192'],
['xir_r1_k5_s1_e6_c320'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
@@ -641,6 +699,8 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
stem_size=32,
channel_multiplier=channel_multiplier,
act_layer=Swish,
attn_layer='sev2',
attn_kwargs=dict(se_ratio=0.25),
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
@@ -707,6 +767,53 @@ def _gen_efficientnet_condconv(
return model
def _gen_efficientnet_eca(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
"""Creates an EfficientNet model w/ ECA attention instead of SE.
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
Paper: https://arxiv.org/abs/1905.11946
EfficientNet params
name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
Args:
channel_multiplier: multiplier to number of channels per layer
depth_multiplier: multiplier to number of repeats per stage
"""
arch_def = [
['ds_r1_k3_s1_e1_c16'],
['ir_r2_k3_s2_e6_c24'],
['ir_r2_k5_s2_e6_c40'],
['ir_r3_k3_s2_e6_c80'],
['ir_r3_k5_s1_e6_c112'],
['ir_r4_k5_s2_e6_c192'],
['ir_r1_k3_s1_e6_c320'],
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
num_features=round_channels(1280, channel_multiplier, 8, None),
stem_size=32,
channel_multiplier=channel_multiplier,
act_layer=Swish,
attn_layer='eca',
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MixNet Small model.
@@ -980,6 +1087,42 @@ def efficientnet_l2(pretrained=False, **kwargs):
return model
@register_model
def efficientnet_eca_b0(pretrained=False, **kwargs):
""" EfficientNet-ECA-B0 """
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_efficientnet_eca(
'efficientnet_eca_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_eca_b1(pretrained=False, **kwargs):
""" EfficientNet-ECA-B1 """
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_efficientnet_eca(
'efficientnet_eca_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_eca_b2(pretrained=False, **kwargs):
""" EfficientNet-ECA-B2 """
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
model = _gen_efficientnet_eca(
'efficientnet_eca_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
return model
@register_model
def xefficientnet_b0(pretrained=False, **kwargs):
""" XEfficientNet-B0 """
# NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
model = _gen_xefficientnet(
'xefficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
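For orientation, a minimal usage sketch of the variants registered above, assuming this branch of timm is installed; there are no pretrained weights behind these entries (url=''), so pretrained must stay False and the input shape below is only a smoke test.

import torch
import timm

# names come from the @register_model entrypoints above
model = timm.create_model('efficientnet_eca_b0', pretrained=False)  # ECA attention in place of SE
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])

# the factorized-depthwise experiment registered above
xmodel = timm.create_model('xefficientnet_b0', pretrained=False)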
@register_model
def efficientnet_es(pretrained=False, **kwargs):
""" EfficientNet-Edge Small. """

@@ -1,8 +1,7 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
-from .layers.activations import sigmoid
-from .layers import create_conv2d, drop_path
+from .layers import create_conv2d, create_attn, drop_path
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
@@ -30,26 +29,21 @@ def resolve_bn_args(kwargs):
return bn_args
-_SE_ARGS_DEFAULT = dict(
-gate_fn=sigmoid,
-act_layer=None,
-reduce_mid=False,
-divisor=1)
-def resolve_se_args(kwargs, in_chs, act_layer=None):
-se_kwargs = kwargs.copy() if kwargs is not None else {}
-# fill in args that aren't specified with the defaults
-for k, v in _SE_ARGS_DEFAULT.items():
-se_kwargs.setdefault(k, v)
-# some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
-if not se_kwargs.pop('reduce_mid'):
-se_kwargs['reduced_base_chs'] = in_chs
-# act_layer override, if it remains None, the containing block's act_layer will be used
-if se_kwargs['act_layer'] is None:
-assert act_layer is not None
-se_kwargs['act_layer'] = act_layer
-return se_kwargs
+def resolve_attn_args(layer, kwargs, in_chs, act_layer=None):
+attn_kwargs = kwargs.copy() if kwargs is not None else {}
+if isinstance(layer, str):
+is_se = layer == 'sev2'
+else:
+is_se = 'SqueezeExciteV2' in layer.__name__
+if is_se:
+# some models, like MobileNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
+if not attn_kwargs.pop('reduce_mid', False):
+attn_kwargs['reduced_base_chs'] = in_chs
+# if act_layer is not defined by attn kwargs, the containing block's act_layer will be used for attn
+if attn_kwargs.get('act_layer', None) is None:
+assert act_layer is not None
+attn_kwargs['act_layer'] = act_layer
+return attn_kwargs
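A quick illustration of the new helper above (a sketch in the context of this module; the channel count and activation are arbitrary): for an SE-style layer it fills in reduced_base_chs and act_layer, for anything else the kwargs pass through untouched.

from torch import nn

# SE path: reduce_mid defaults to False, so reduced_base_chs comes from in_chs,
# and the block's act_layer is filled in when the attn kwargs do not specify one.
resolve_attn_args('sev2', dict(se_ratio=0.25), in_chs=32, act_layer=nn.ReLU)
# -> {'se_ratio': 0.25, 'reduced_base_chs': 32, 'act_layer': nn.ReLU}

# non-SE path (e.g. ECA): kwargs are returned as-is
resolve_attn_args('eca', dict(kernel_size=5), in_chs=32, act_layer=nn.ReLU)
# -> {'kernel_size': 5}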
def make_divisible(v, divisor=8, min_value=None):
@@ -90,26 +84,6 @@ class ChannelShuffle(nn.Module):
)
-class SqueezeExcite(nn.Module):
-def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
-act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
-super(SqueezeExcite, self).__init__()
-self.gate_fn = gate_fn
-reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
-self.avg_pool = nn.AdaptiveAvgPool2d(1)
-self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
-self.act1 = act_layer(inplace=True)
-self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
-def forward(self, x):
-x_se = self.avg_pool(x)
-x_se = self.conv_reduce(x_se)
-x_se = self.act1(x_se)
-x_se = self.conv_expand(x_se)
-x = x * self.gate_fn(x_se)
-return x
class ConvBnAct(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
@@ -140,11 +114,10 @@ class DepthwiseSeparableConv(nn.Module):
"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
-pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
+pw_kernel_size=1, pw_act=False, attn_layer=None, attn_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
norm_kwargs = norm_kwargs or {}
-has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act  # activation after point-wise conv
self.drop_path_rate = drop_path_rate
@@ -154,10 +127,10 @@ class DepthwiseSeparableConv(nn.Module):
self.bn1 = norm_layer(in_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
-# Squeeze-and-excitation
-if has_se:
-se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
+# Attention block (Squeeze-Excitation, ECA, etc)
+if attn_layer is not None:
+attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
+self.se = create_attn(attn_layer, in_chs, **attn_kwargs)
else:
self.se = None
@@ -199,13 +172,12 @@ class InvertedResidual(nn.Module):
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
-se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
+attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
conv_kwargs=None, drop_path_rate=0.):
super(InvertedResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
conv_kwargs = conv_kwargs or {}
mid_chs = make_divisible(in_chs * exp_ratio)
-has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
@@ -221,10 +193,10 @@ class InvertedResidual(nn.Module):
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True)
-# Squeeze-and-excitation
-if has_se:
-se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
+# Attention block (Squeeze-Excitation, ECA, etc)
+if attn_layer is not None:
+attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
+self.se = create_attn(attn_layer, mid_chs, **attn_kwargs)
else:
self.se = None
@@ -256,7 +228,162 @@ class InvertedResidual(nn.Module):
x = self.bn2(x)
x = self.act2(x)
-# Squeeze-and-excitation
+# Attention
if self.se is not None:
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)
x = self.bn3(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += residual
return x
class XDepthwiseSeparableConv(nn.Module):
""" DepthwiseSeparable block
Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
(factor of 1.0). This is an alternative to having an IR with an optional first pw conv.
"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
pw_kernel_size=1, pw_act=False, attn_layer=None, attn_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
super(XDepthwiseSeparableConv, self).__init__()
norm_kwargs = norm_kwargs or {}
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
self.drop_path_rate = drop_path_rate
conv_kwargs = {}
self.conv_dw_2x2 = create_conv2d(
in_chs, in_chs, 2, stride=stride, dilation=dilation,
padding='same', depthwise=True, **conv_kwargs)
self.conv_dw_1xk = create_conv2d(
in_chs, in_chs, (1, dw_kernel_size), stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.conv_dw_kx1 = create_conv2d(
in_chs, in_chs, (dw_kernel_size, 1), stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.bn1 = norm_layer(in_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Attention block (Squeeze-Excitation, ECA, etc)
if attn_layer is not None:
attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
self.se = create_attn(attn_layer, in_chs, **attn_kwargs)
else:
self.se = None
self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
def feature_module(self, location):
# no expansion in this block, pre pw only feature extraction point
return 'conv_pw'
def feature_channels(self, location):
return self.conv_pw.in_channels
def forward(self, x):
residual = x
x = self.conv_dw_2x2(x)
x = self.conv_dw_1xk(x)
x = self.conv_dw_kx1(x)
x = self.bn1(x)
x = self.act1(x)
if self.se is not None:
x = self.se(x)
x = self.conv_pw(x)
x = self.bn2(x)
x = self.act2(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += residual
return x
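The factorized depthwise stack above (2x2, then 1xk, then kx1) can be sketched with plain nn.Conv2d to see the shape behaviour; this standalone snippet uses symmetric manual padding instead of the branch's create_conv2d(..., padding='same', pad_shift=...) plumbing, so it only illustrates the factorization, not the shifted-padding scheme.

import torch
from torch import nn
import torch.nn.functional as F

chs, k = 16, 5
dw_2x2 = nn.Conv2d(chs, chs, 2, groups=chs)                      # even kernel, padded manually below
dw_1xk = nn.Conv2d(chs, chs, (1, k), padding=(0, k // 2), groups=chs)
dw_kx1 = nn.Conv2d(chs, chs, (k, 1), padding=(k // 2, 0), groups=chs)

x = torch.randn(1, chs, 14, 14)
x = F.pad(x, [0, 1, 0, 1])                                       # 'same'-style pad for the 2x2 kernel
x = dw_kx1(dw_1xk(dw_2x2(x)))
print(x.shape)  # torch.Size([1, 16, 14, 14]) -- spatial size preserved at stride 1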
class XInvertedResidual(nn.Module):
""" Inverted residual block w/ optional SE and CondConv routing"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, pad_shift=0,
attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
conv_kwargs=None, drop_path_rate=0.):
super(XInvertedResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
conv_kwargs = conv_kwargs or {}
mid_chs = make_divisible(in_chs * exp_ratio)
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
# Point-wise expansion
self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Depth-wise convolution
self.conv_dw_2x2 = create_conv2d(
mid_chs, mid_chs, 2, stride=stride, dilation=dilation,
padding='same', depthwise=True, pad_shift=pad_shift, **conv_kwargs)
self.conv_dw_1xk = create_conv2d(
mid_chs, mid_chs, (1, dw_kernel_size), stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.conv_dw_kx1 = create_conv2d(
mid_chs, mid_chs, (dw_kernel_size, 1), stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True)
# Attention block (Squeeze-Excitation, ECA, etc)
if attn_layer is not None:
attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
self.se = create_attn(attn_layer, mid_chs, **attn_kwargs)
else:
self.se = None
# Point-wise linear projection
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
self.bn3 = norm_layer(out_chs, **norm_kwargs)
def feature_module(self, location):
if location == 'post_exp':
return 'act1'
return 'conv_pwl'
def feature_channels(self, location):
if location == 'post_exp':
return self.conv_pw.out_channels
# location == 'pre_pw'
return self.conv_pwl.in_channels
def forward(self, x):
residual = x
# Point-wise expansion
x = self.conv_pw(x)
x = self.bn1(x)
x = self.act1(x)
# Depth-wise convolution
x = self.conv_dw_2x2(x)
x = self.conv_dw_1xk(x)
x = self.conv_dw_kx1(x)
x = self.bn2(x)
x = self.act2(x)
# Attention
if self.se is not None:
x = self.se(x)
@@ -278,7 +405,7 @@ class CondConvResidual(InvertedResidual):
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
-se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
+attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
num_experts=0, drop_path_rate=0.):
self.num_experts = num_experts
@@ -287,7 +414,7 @@ class CondConvResidual(InvertedResidual):
super(CondConvResidual, self).__init__(
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
-pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
+pw_kernel_size=pw_kernel_size, attn_layer=attn_layer, attn_kwargs=attn_kwargs,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
drop_path_rate=drop_path_rate)
@@ -310,7 +437,7 @@ class CondConvResidual(InvertedResidual):
x = self.bn2(x)
x = self.act2(x)
-# Squeeze-and-excitation
+# Attention
if self.se is not None:
x = self.se(x)
@@ -330,7 +457,7 @@ class EdgeResidual(nn.Module):
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
-se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
+attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
drop_path_rate=0.):
super(EdgeResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
@@ -338,7 +465,6 @@ class EdgeResidual(nn.Module):
mid_chs = make_divisible(fake_in_chs * exp_ratio)
else:
mid_chs = make_divisible(in_chs * exp_ratio)
-has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
@@ -347,10 +473,10 @@ class EdgeResidual(nn.Module):
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
-# Squeeze-and-excitation
-if has_se:
-se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
-self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
+# Attention block (Squeeze-Excitation, ECA, etc)
+if attn_layer is not None:
+attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer)
+self.se = create_attn(attn_layer, mid_chs, **attn_kwargs)
else:
self.se = None
@@ -378,7 +504,7 @@ class EdgeResidual(nn.Module):
x = self.bn1(x)
x = self.act1(x)
-# Squeeze-and-excitation
+# Attention
if self.se is not None:
x = self.se(x)

@@ -79,10 +79,18 @@ def _decode_block_str(block_str):
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
fake_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
attn_layer = None
attn_kwargs = None
if 'se' in options:
attn_layer = 'sev2'
attn_kwargs = dict(se_ratio=float(options['se']))
elif 'eca' in options:
attn_layer = 'ceca'
attn_kwargs = dict(kernel_size=int(options['eca']))
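A rough sketch of how the parsing above maps block-definition strings to attention settings; 'se0.25' appears in the arch defs elsewhere in this diff, while an 'eca<k>' option is accepted by the parser but not used in any arch def here (the decoded dicts carry more keys than shown).

examples = {
    'ir_r2_k5_s2_e6_c40_se0.25': ('sev2', dict(se_ratio=0.25)),
    'ir_r2_k5_s2_e6_c40_eca3': ('ceca', dict(kernel_size=3)),
    'ir_r2_k5_s2_e6_c40': (None, None),
}
for block_str, (attn_layer, attn_kwargs) in examples.items():
    print(f'{block_str:28s} -> attn_layer={attn_layer}, attn_kwargs={attn_kwargs}')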
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
-if block_type == 'ir':
+if block_type == 'ir' or block_type == 'xir':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
@@ -90,20 +98,22 @@ def _decode_block_str(block_str):
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
-se_ratio=float(options['se']) if 'se' in options else None,
+attn_layer=attn_layer,
attn_kwargs=attn_kwargs,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
)
if 'cc' in options:
block_args['num_experts'] = int(options['cc'])
-elif block_type == 'ds' or block_type == 'dsa':
+elif block_type == 'ds' or block_type == 'dsa' or block_type == 'xds':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
-se_ratio=float(options['se']) if 'se' in options else None,
+attn_layer=attn_layer,
attn_kwargs=attn_kwargs,
stride=int(options['s']),
act_layer=act_layer,
pw_act=block_type == 'dsa',
@@ -117,7 +127,8 @@ def _decode_block_str(block_str):
out_chs=int(options['c']),
exp_ratio=float(options['e']),
fake_in_chs=fake_in_chs,
-se_ratio=float(options['se']) if 'se' in options else None,
+attn_layer=attn_layer,
attn_kwargs=attn_kwargs,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
@@ -201,7 +212,7 @@ class EfficientNetBuilder:
"""
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
-output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
+output_stride=32, pad_type='', act_layer=None, attn_layer=None, attn_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='',
verbose=False):
self.channel_multiplier = channel_multiplier
@@ -210,7 +221,8 @@ class EfficientNetBuilder:
self.output_stride = output_stride
self.pad_type = pad_type
self.act_layer = act_layer
-self.se_kwargs = se_kwargs
+self.attn_layer = attn_layer
self.attn_kwargs = attn_kwargs
self.norm_layer = norm_layer
self.norm_kwargs = norm_kwargs
self.drop_path_rate = drop_path_rate
@@ -220,6 +232,7 @@ class EfficientNetBuilder:
# state updated during build, consumed by model
self.in_chs = None
self.x_count = 0
self.features = OrderedDict()
def _round_channels(self, chs):
@@ -239,35 +252,45 @@ class EfficientNetBuilder:
# block act fn overrides the model default
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
assert ba['act_layer'] is not None
if 'attn_layer' in ba:
assert 'attn_kwargs' in ba  # block args should have both or neither
# per-block attn layer overrides model default
ba['attn_layer'] = ba['attn_layer'] if ba['attn_layer'] is not None else self.attn_layer
if self.attn_kwargs is not None:
# merge per-block attn kwargs with model if both exist
if ba['attn_kwargs'] is None:
ba['attn_kwargs'] = self.attn_kwargs
else:
ba['attn_kwargs'].update(self.attn_kwargs)
ba['drop_path_rate'] = drop_path_rate
if bt == 'ir':
-ba['drop_path_rate'] = drop_path_rate
-ba['se_kwargs'] = self.se_kwargs
-if self.verbose:
-logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba)
else:
block = InvertedResidual(**ba)
elif bt == 'xir':
ba['pad_shift'] = self.x_count
block = XInvertedResidual(**ba)
self.x_count = (self.x_count + 1) % 4
elif bt == 'ds' or bt == 'dsa':
-ba['drop_path_rate'] = drop_path_rate
-ba['se_kwargs'] = self.se_kwargs
-if self.verbose:
-logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
block = DepthwiseSeparableConv(**ba)
elif bt == 'xds':
ba['pad_shift'] = self.x_count
block = XDepthwiseSeparableConv(**ba)
self.x_count = (self.x_count + 1) % 4
elif bt == 'er':
-ba['drop_path_rate'] = drop_path_rate
-ba['se_kwargs'] = self.se_kwargs
-if self.verbose:
-logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
block = EdgeResidual(**ba)
elif bt == 'cn':
-if self.verbose:
-logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
+del ba['drop_path_rate']
block = ConvBnAct(**ba)
else:
assert False, 'Unknown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs']  # update in_chs for arg of next block
if self.verbose:
logging.info(' {} {}, Args: {}'.format(block.__class__.__name__, block_idx, str(ba)))
return block
def __call__(self, in_chs, model_block_args):
@@ -359,7 +382,7 @@ class EfficientNetBuilder:
return stages
-def _init_weight_goog(m, n='', fix_group_fanout=False):
+def _init_weight_goog(m, n='', fix_group_fanout=True):
""" Weight initialization as per Tensorflow official implementations.
Args:

@@ -11,7 +11,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
-from .layers import SEModule
+from .layers import SqueezeExcite
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .resnet import ResNet, Bottleneck, BasicBlock
@@ -321,7 +321,7 @@ def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kw
default_cfg = default_cfgs['gluon_seresnext50_32x4d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
-num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs)
+num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SqueezeExcite), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -335,7 +335,7 @@ def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k
default_cfg = default_cfgs['gluon_seresnext101_32x4d']
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4,
-num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs)
+num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SqueezeExcite), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -347,7 +347,7 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k
"""Constructs a SEResNeXt-101-64x4d model.
"""
default_cfg = default_cfgs['gluon_seresnext101_64x4d']
-block_args = dict(attn_layer=SEModule)
+block_args = dict(attn_layer=SqueezeExcite)
model = ResNet(
Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4,
num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)
@@ -362,7 +362,7 @@ def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs an SENet-154 model.
"""
default_cfg = default_cfgs['gluon_senet154']
-block_args = dict(attn_layer=SEModule)
+block_args = dict(attn_layer=SqueezeExcite)
model = ResNet(
Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3,
block_reduce_first=2, num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs)

@@ -7,8 +7,8 @@ from .cond_conv2d import CondConv2d, get_condconv_initializer
from .create_conv2d import create_conv2d
from .create_attn import create_attn
from .selective_kernel import SelectiveKernelConv
-from .se import SEModule
-from .eca import EcaModule, CecaModule
+from .se import SqueezeExcite, SqueezeExciteV2
+from .eca import EfficientChannelAttn, CircularEfficientChannelAttn
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d

@@ -75,9 +75,9 @@ class LightSpatialAttn(nn.Module):
return x * x_attn.sigmoid()
-class CbamModule(nn.Module):
+class ConvBlockAttn(nn.Module):
def __init__(self, channels, spatial_kernel_size=7):
-super(CbamModule, self).__init__()
+super(ConvBlockAttn, self).__init__()
self.channel = ChannelAttn(channels)
self.spatial = SpatialAttn(spatial_kernel_size)
@@ -87,9 +87,9 @@ class CbamModule(nn.Module):
return x
-class LightCbamModule(nn.Module):
+class LightConvBlockAttn(nn.Module):
def __init__(self, channels, spatial_kernel_size=7):
-super(LightCbamModule, self).__init__()
+super(LightConvBlockAttn, self).__init__()
self.channel = LightChannelAttn(channels)
self.spatial = LightSpatialAttn(spatial_kernel_size)

@@ -13,8 +13,8 @@ from .padding import get_padding, pad_same, is_static_pad
def conv2d_same(
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
-padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
-x = pad_same(x, weight.shape[-2:], stride, dilation)
+padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1, pad_shift: int = 0):
+x = pad_same(x, weight.shape[-2:], stride, dilation, pad_shift)
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
@@ -23,12 +23,14 @@ class Conv2dSame(nn.Conv2d):
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-padding=0, dilation=1, groups=1, bias=True):
+padding=0, dilation=1, groups=1, bias=True, pad_shift=0):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
self.pad_shift = pad_shift
def forward(self, x):
-return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+return conv2d_same(
+x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, self.pad_shift)
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:

@@ -3,9 +3,9 @@
Hacked together by Ross Wightman
"""
import torch
-from .se import SEModule
-from .eca import EcaModule, CecaModule
-from .cbam import CbamModule, LightCbamModule
+from .se import SqueezeExcite, SqueezeExciteV2
+from .eca import EfficientChannelAttn, CircularEfficientChannelAttn
+from .cbam import ConvBlockAttn, LightConvBlockAttn
def create_attn(attn_type, channels, **kwargs):
@@ -14,20 +14,19 @@ def create_attn(attn_type, channels, **kwargs):
if isinstance(attn_type, str):
attn_type = attn_type.lower()
if attn_type == 'se':
-module_cls = SEModule
+module_cls = SqueezeExcite
elif attn_type == 'sev2':
module_cls = SqueezeExciteV2
elif attn_type == 'eca':
-module_cls = EcaModule
-elif attn_type == 'eca':
-module_cls = CecaModule
+module_cls = EfficientChannelAttn
+elif attn_type == 'ceca':
+module_cls = CircularEfficientChannelAttn
elif attn_type == 'cbam':
-module_cls = CbamModule
+module_cls = ConvBlockAttn
elif attn_type == 'lcbam':
-module_cls = LightCbamModule
+module_cls = LightConvBlockAttn
else:
assert False, "Invalid attn module (%s)" % attn_type
-elif isinstance(attn_type, bool):
-if attn_type:
-module_cls = SEModule
else:
module_cls = attn_type
if module_cls is not None:
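For reference, a small sketch of the renamed modules being resolved through create_attn (assumes this branch's timm.models.layers; the channel count is arbitrary, and all of these modules are shape-preserving).

import torch
from timm.models.layers import create_attn

x = torch.randn(2, 64, 14, 14)
se = create_attn('se', 64)                      # SqueezeExcite, divisor-style reduction
sev2 = create_attn('sev2', 64, se_ratio=0.25)   # SqueezeExciteV2, ratio-style (EfficientNet/MobileNetV3)
ceca = create_attn('ceca', 64)                  # circularly padded ECA
for attn in (se, sev2, ceca):
    assert attn(x).shape == x.shape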

@@ -34,11 +34,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import math
import torch
from torch import nn
import torch.nn.functional as F
-class EcaModule(nn.Module):
+class EfficientChannelAttn(nn.Module):
"""Constructs an ECA module.
Args:
@@ -49,8 +50,8 @@ class EcaModule(nn.Module):
(default=None. if channel size not given, use k_size given for kernel size.)
kernel_size: Adaptive selection of kernel size (default=3)
"""
-def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
-super(EcaModule, self).__init__()
+def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, gate_fn=None):
+super(EfficientChannelAttn, self).__init__()
assert kernel_size % 2 == 1
if channels is not None:
@@ -59,20 +60,34 @@ class EcaModule(nn.Module):
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
self.gate_fn = gate_fn
def forward(self, x):
-# Feature descriptor on the global spatial information
y = self.avg_pool(x)
-# Reshape for convolution
-y = y.view(x.shape[0], 1, -1)
-# Two different branches of ECA module
+y = y.view(x.shape[0], 1, -1)  # Reshape 4d -> 3d for 1d convolution
y = self.conv(y)
-# Multi-scale information fusion
-y = y.view(x.shape[0], -1, 1, 1).sigmoid()
+y = y.view(x.shape[0], -1, 1, 1)  # Back to 4d
+y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y)
return x * y.expand_as(x)
-class CecaModule(nn.Module):
+def padding1d_circular(input, pad):
r"""input: torch.tensor([[[0., 1., 2.],
[3., 4., 5.]]])
pad: (1, 2)
output: tensor([[[2., 0., 1., 2., 0., 1.],
[5., 3., 4., 5., 3., 4.]]])
from: https://github.com/pytorch/pytorch/issues/24504
"""
input = torch.cat([input, input[:, :, 0:pad[-1]]], dim=2)
if pad[-1] == 0 and pad[-2] != 0:
return torch.cat([input[:, :, -(pad[-1] + pad[-2]):], input], dim=2)
else:
return torch.cat([input[:, :, -(pad[-1] + pad[-2]):-pad[-1]], input], dim=2)
class CircularEfficientChannelAttn(nn.Module):
"""Constructs a circular ECA module. """Constructs a circular ECA module.
ECA module where the conv uses circular padding rather than zero padding. ECA module where the conv uses circular padding rather than zero padding.
@@ -92,33 +107,28 @@ class CecaModule(nn.Module):
kernel_size: Adaptive selection of kernel size (default=3)
"""
-def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
-super(CecaModule, self).__init__()
+def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, gate_fn=None):
+super(CircularEfficientChannelAttn, self).__init__()
assert kernel_size % 2 == 1
if channels is not None:
t = int(abs(math.log(channels, 2) + beta) / gamma)
kernel_size = max(t if t % 2 else t + 1, 3)
-self.avg_pool = nn.AdaptiveAvgPool2d(1)
-#pytorch circular padding mode is buggy as of pytorch 1.4
-#see https://github.com/pytorch/pytorch/pull/17240
-#implement manual circular padding
-self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
-self.padding = (kernel_size - 1) // 2
+# pytorch conv circular padding mode is buggy as of pytorch 1.4, will implement manually
+# see https://github.com/pytorch/pytorch/pull/17240
+# https://github.com/pytorch/pytorch/issues/24504
+p = (kernel_size - 1) // 2
+self.padding = (p, p)
+self.avg_pool = nn.AdaptiveAvgPool2d(1)
+self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
+self.gate_fn = gate_fn
def forward(self, x):
-# Feature descriptor on the global spatial information
y = self.avg_pool(x)
-# Manually implement circular padding, F.pad does not seemed to be bugged
-y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
-# Two different branches of ECA module
+y = padding1d_circular(y.view(x.shape[0], 1, -1), self.padding)  # manual circular padding
y = self.conv(y)
-# Multi-scale information fusion
-y = y.view(x.shape[0], -1, 1, 1).sigmoid()
+y = y.view(x.shape[0], -1, 1, 1)
+y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y)
return x * y.expand_as(x)
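To see what the manual circular padding does to the pooled channel descriptor, a standalone check (the helper is copied from above so the snippet runs on its own; values are arbitrary).

import torch

def padding1d_circular(input, pad):  # copy of the helper above
    input = torch.cat([input, input[:, :, 0:pad[-1]]], dim=2)
    if pad[-1] == 0 and pad[-2] != 0:
        return torch.cat([input[:, :, -(pad[-1] + pad[-2]):], input], dim=2)
    else:
        return torch.cat([input[:, :, -(pad[-1] + pad[-2]):-pad[-1]], input], dim=2)

y = torch.arange(6.).view(1, 1, 6)       # pooled per-channel descriptor, C=6
print(padding1d_circular(y, (1, 1)))     # tensor([[[5., 0., 1., 2., 3., 4., 5., 0.]]])
# the 1d conv then sees the channel dimension as a ring instead of zero-padded ends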

@@ -7,7 +7,7 @@ from torch._six import container_abcs
# From PyTorch internals
-def _ntuple(n):
+def ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
@@ -15,13 +15,19 @@ def _ntuple(n):
return parse
-tup_single = _ntuple(1)
-tup_pair = _ntuple(2)
-tup_triple = _ntuple(3)
-tup_quadruple = _ntuple(4)
+tup_single = ntuple(1)
+tup_pair = ntuple(2)
+tup_triple = ntuple(3)
+tup_quadruple = ntuple(4)
def make_divisible(v, divisor=8, min_value=None):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
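A couple of worked values for the rounding rule above (in the context of this module; inputs are arbitrary): round to the nearest multiple of divisor, but never end up more than 10% below the requested value.

assert make_divisible(64 * 0.25, divisor=8) == 16   # already a multiple of 8
assert make_divisible(38.4, divisor=8) == 40        # nearest multiple of 8
assert make_divisible(9, divisor=8) == 16           # rounding down to 8 would be >10% below 9, so bump up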

@@ -4,12 +4,17 @@ Hacked together by Ross Wightman
"""
import math
from typing import List
from .helpers import ntuple
import torch.nn.functional as F
# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
if isinstance(kernel_size, (list, tuple)):
stride = ntuple(len(kernel_size))(stride)
dilation = ntuple(len(kernel_size))(dilation)
return [get_padding(k, s, d) for k, s, d in zip(kernel_size, stride, dilation)]
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
@@ -25,9 +30,17 @@ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
# Dynamically pad input x with 'SAME' padding for conv with specified args
-def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)):
+def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), shift: int = 0):
ih, iw = x.size()[-2:]
pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
if pad_h > 0 or pad_w > 0:
-x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+if shift == 0:
pl = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] # ul
elif shift == 1:
pl = [pad_w - pad_w // 2, pad_w // 2, pad_h - pad_h // 2, pad_h // 2] # lr
elif shift == 2:
pl = [pad_w - pad_w // 2, pad_w // 2, pad_h // 2, pad_h - pad_h // 2] # ur
else:
pl = [pad_w // 2, pad_w - pad_w // 2, pad_h - pad_h // 2, pad_h // 2] # ll
x = F.pad(x, pl)
return x
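A standalone sketch of what the four shift modes above amount to for an odd padding amount; F.pad takes [left, right, top, bottom], so with pad_h = pad_w = 1 the single extra row/column of zeros lands on a different side for each shift value (the code above labels them ul, lr, ur, ll).

import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 3, 3)
pad_h = pad_w = 1
pads = {
    0: [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
    1: [pad_w - pad_w // 2, pad_w // 2, pad_h - pad_h // 2, pad_h // 2],
    2: [pad_w - pad_w // 2, pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
    3: [pad_w // 2, pad_w - pad_w // 2, pad_h - pad_h // 2, pad_h // 2],
}
for shift, pl in pads.items():
    print(shift, F.pad(x, pl)[0, 0])   # all 4x4; only the location of the zero row/column differs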

@@ -1,12 +1,22 @@
import torch
from torch import nn as nn
from .helpers import make_divisible
-class SEModule(nn.Module):
-def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
-super(SEModule, self).__init__()
+class SqueezeExcite(nn.Module):
+""" Squeeze-and-Excitation module as used in Pytorch SENet, SE-ResNeXt implementations
Args:
channels (int): number of input and output channels
reduction (int, float): divisor for attention (squeezed) channels
act_layer (nn.Module): override the default ReLU activation
"""
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisible_by=1):
super(SqueezeExcite, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
-reduction_channels = max(channels // reduction, 8)
+reduction_channels = make_divisible(channels // reduction, divisible_by)
self.fc1 = nn.Conv2d(
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.act = act_layer(inplace=True)
@@ -19,3 +29,38 @@ class SEModule(nn.Module):
x_se = self.act(x_se)
x_se = self.fc2(x_se)
return x * x_se.sigmoid()
class SqueezeExciteV2(nn.Module):
""" Squeeze-and-Excitation module as used in EfficientNet, MobileNetV3, related models
Differs from the original SqueezeExcite impl in that:
* reduction is specified as a float multiplier instead of divisor (se_ratio)
* gate function is changeable from sigmoid to alternate (ie hard_sigmoid)
* layer names match those in weights for the EfficientNet/MobileNetV3 families
Args:
channels (int): number of input and output channels
se_ratio (float): multiplier for attention (squeezed) channels
reduced_base_chs (int): specify alternate channel count to base the reduction channels on
act_layer (nn.Module): override the default ReLU activation
gate_fn (callable): override the default gate function
"""
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
act_layer=nn.ReLU, gate_fn=torch.sigmoid, divisible_by=1, **_):
super(SqueezeExciteV2, self).__init__()
self.gate_fn = gate_fn
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisible_by)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.act1 = act_layer(inplace=True)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
def forward(self, x):
x_se = self.avg_pool(x)
x_se = self.conv_reduce(x_se)
x_se = self.act1(x_se)
x_se = self.conv_expand(x_se)
x = x * self.gate_fn(x_se)
return x
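A quick sketch contrasting the two SE flavours in this file (in the context of this module; channel counts arbitrary, both modules are shape-preserving): SqueezeExcite keeps the divisor-style reduction used by the ResNet-family models, while SqueezeExciteV2 uses the ratio-style se_ratio and a swappable gate expected by the EfficientNet/MobileNetV3 blocks.

import torch

x = torch.randn(2, 64, 14, 14)
se = SqueezeExcite(64, reduction=16)        # squeezed to 64 // 16 = 4 channels
sev2 = SqueezeExciteV2(64, se_ratio=0.25)   # squeezed to 64 * 0.25 = 16 channels, gate defaults to torch.sigmoid
assert se(x).shape == x.shape and sev2(x).shape == x.shape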

@@ -30,10 +30,12 @@ def _cfg(url='', **kwargs):
default_cfgs = {
-'mobilenetv3_large_075': _cfg(url=''),
-'mobilenetv3_large_100': _cfg(url=''),
-'mobilenetv3_small_075': _cfg(url=''),
-'mobilenetv3_small_100': _cfg(url=''),
+'mobilenetv3_large_075': _cfg(url='', interpolation='bicubic'),
+'mobilenetv3_large_100': _cfg(url='', interpolation='bicubic'),
+'mobilenetv3_small_075': _cfg(url='', interpolation='bicubic'),
+'mobilenetv3_small_100': _cfg(url='', interpolation='bicubic'),
'mobilenetv3_eca_large': _cfg(url='', interpolation='bicubic'),
'xmobilenetv3_large_100': _cfg(url='', interpolation='bicubic'),
'mobilenetv3_rw': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
interpolation='bicubic'),
@@ -57,7 +59,7 @@ default_cfgs = {
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
}
-_DEBUG = False
+_DEBUG = True
class MobileNetV3(nn.Module):
@@ -72,7 +74,7 @@ class MobileNetV3(nn.Module):
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
-se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
+attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
super(MobileNetV3, self).__init__()
self.num_classes = num_classes
@@ -89,7 +91,7 @@ class MobileNetV3(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
-channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
+channel_multiplier, 8, None, 32, pad_type, act_layer, attn_layer, attn_kwargs,
norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features
@@ -148,7 +150,7 @@ class MobileNetV3Features(nn.Module):
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
-act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
+act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., attn_layer=None, attn_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(MobileNetV3Features, self).__init__()
norm_kwargs = norm_kwargs or {}
@@ -169,7 +171,7 @@ class MobileNetV3Features(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
-channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
+channel_multiplier, 8, None, output_stride, pad_type, act_layer, attn_layer, attn_kwargs,
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(self._in_chs, block_args))
self.feature_info = builder.features  # builder provides info about feature channels for each block
@ -256,7 +258,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs), norm_kwargs=resolve_bn_args(kwargs),
act_layer=HardSwish, act_layer=HardSwish,
se_kwargs=dict(gate_fn=hard_sigmoid, reduce_mid=True, divisor=1), attn_kwargs=dict(gate_fn=hard_sigmoid, reduce_mid=True, divisor=1),
**kwargs, **kwargs,
) )
model = _create_model(model_kwargs, default_cfgs[variant], pretrained) model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
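For context on the kwargs above: `gate_fn=hard_sigmoid` replaces the usual sigmoid gate of squeeze-and-excite with the piecewise-linear h-sigmoid from the MobileNet-V3 paper, while `reduce_mid` and `divisor` control how the squeezed width is derived. A self-contained toy SE block along these lines (illustrative only, not the SqueezeExcite layer this branch actually builds):

import torch
import torch.nn as nn
import torch.nn.functional as F

def hard_sigmoid(x):
    # h-sigmoid as defined in the MobileNet-V3 paper: ReLU6(x + 3) / 6
    return F.relu6(x + 3.) / 6.

class ToySqueezeExcite(nn.Module):
    """Toy SE block with a pluggable gate function."""
    def __init__(self, channels, se_ratio=0.25, gate_fn=torch.sigmoid, divisor=1):
        super().__init__()
        # round the squeezed width down to a multiple of `divisor`
        reduced = max(divisor, int(channels * se_ratio) // divisor * divisor)
        self.conv_reduce = nn.Conv2d(channels, reduced, 1, bias=True)
        self.conv_expand = nn.Conv2d(reduced, channels, 1, bias=True)
        self.gate_fn = gate_fn

    def forward(self, x):
        se = x.mean((2, 3), keepdim=True)   # squeeze: global average pool
        se = F.relu(self.conv_reduce(se))   # reduce
        se = self.conv_expand(se)           # expand
        return x * self.gate_fn(se)         # excite with the chosen gate

x = torch.randn(2, 40, 14, 14)
print(ToySqueezeExcite(40, gate_fn=hard_sigmoid, divisor=8)(x).shape)  # torch.Size([2, 40, 14, 14])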
@ -352,7 +354,179 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
channel_multiplier=channel_multiplier, channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs), norm_kwargs=resolve_bn_args(kwargs),
act_layer=act_layer, act_layer=act_layer,
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8), attn_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisible_by=8),
**kwargs,
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
def _gen_xmobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MobileNet-V3 model.
Ref impl: ?
Paper: https://arxiv.org/abs/1905.02244
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
if 'small' in variant:
num_features = 1024
if 'minimal' in variant:
act_layer = nn.ReLU
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16'],
# stage 1, 56x56 in
['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
# stage 2, 28x28 in
['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
# stage 3, 14x14 in
['ir_r2_k3_s1_e3_c48'],
# stage 4, 14x14in
['ir_r3_k3_s2_e6_c96'],
# stage 6, 7x7 in
['cn_r1_k1_s1_c576'],
]
else:
act_layer = HardSwish
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
# stage 1, 56x56 in
['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
# stage 2, 28x28 in
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish
# stage 3, 14x14 in
['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish
# stage 4, 14x14in
['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c576'], # hard-swish
]
else:
num_features = 1280
if 'minimal' in variant:
act_layer = nn.ReLU
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16'],
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
# stage 2, 56x56 in
['ir_r3_k3_s2_e3_c40'],
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
# stage 4, 14x14in
['ir_r2_k3_s1_e6_c112'],
# stage 5, 14x14in
['ir_r3_k3_s2_e6_c160'],
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'],
]
else:
act_layer = HardSwish
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_nre'], # relu
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_se0.25_nre', 'xir_r2_k5_s2_e3_c40_se0.25_nre'], # relu
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
# stage 4, 14x14in
['xir_r2_k5_s1_e6_c112_se0.25'], # hard-swish
# stage 5, 14x14in
['ir_r1_k5_s2_e6_c160_se0.25', 'xir_r2_k5_s2_e6_c160_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,
stem_size=16,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=act_layer,
attn_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisible_by=8),
**kwargs,
)
model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
return model
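The block strings in these arch_defs are timm's compact block-definition notation, decoded by decode_arch_def above: underscore-separated tags give the block type, repeat count, kernel size, stride, expansion ratio, output channels, and optional attention/activation overrides. One entry spelled out for reference (the 'xir' prefix is this branch's experimental block type, so its exact semantics are an assumption here):

# Field-by-field reading of one entry (illustrative; decode_arch_def does the real parsing):
block = 'ir_r2_k5_s1_e6_c112_se0.25'
# ir     -> inverted residual (MBConv) block
# r2     -> repeat twice within the stage
# k5     -> 5x5 depthwise kernel
# s1     -> stride 1
# e6     -> expansion ratio 6
# c112   -> 112 output channels
# se0.25 -> squeeze-and-excite with ratio 0.25
# A trailing _nre pins the block activation to ReLU instead of the model-wide
# default (HardSwish here), matching the "# relu" comments in the arch_defs.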
def _gen_mobilenet_v3_eca(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MobileNet-V3 model.
Ref impl: ?
Paper: https://arxiv.org/abs/1905.02244
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
if 'small' in variant:
num_features = 1024
act_layer = HardSwish
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16_nre'], # relu
# stage 1, 56x56 in
['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
# stage 2, 28x28 in
['ir_r1_k5_s2_e4_c40', 'ir_r2_k5_s1_e6_c40'], # hard-swish
# stage 3, 14x14 in
['ir_r2_k5_s1_e3_c48'], # hard-swish
# stage 4, 14x14in
['ir_r3_k5_s2_e6_c96'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c576'], # hard-swish
]
else:
num_features = 1280
act_layer = HardSwish
# arch_def = [
# # stage 0, 112x112 in
# ['ds_r1_k3_s1_e1_c16_nre'], # relu
# # stage 1, 112x112 in
# ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
# # stage 2, 56x56 in
# ['ir_r3_k5_s2_e3_c40_nre'], # relu
# # stage 3, 28x28 in
# ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
# # stage 4, 14x14in
# ['ir_r2_k3_s1_e6_c112'], # hard-swish
# # stage 5, 14x14in
# ['ir_r3_k5_s2_e6_c160'], # hard-swish
# # stage 6, 7x7 in
# ['cn_r1_k1_s1_c960'], # hard-swish
# ]
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_nre'], # relu
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_eca3_nre'], # relu
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80_eca3', 'ir_r1_k3_s1_e2.5_c80_eca3', 'ir_r2_k3_s1_e2.3_c80_eca3'], # hard-swish
# stage 4, 14x14in
['ir_r2_k3_s1_e6_c112_eca5'], # hard-swish
# stage 5, 14x14in
['ir_r3_k5_s2_e6_c160_eca5'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,
stem_size=16,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=act_layer,
#attn_layer='ceca',
attn_kwargs=dict(gate_fn=hard_sigmoid),
**kwargs, **kwargs,
) )
model = _create_model(model_kwargs, default_cfgs[variant], pretrained) model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
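The eca3/eca5 tags (and the commented-out attn_layer='ceca', presumably the circular-padding ECA variant) point to Efficient Channel Attention replacing squeeze-and-excite in those stages, with the digit most likely selecting the 1-D convolution kernel size. A minimal sketch of the standard ECA-Net formulation for reference, not this branch's implementation; the gate_fn parameter mirrors the attn_kwargs above, where sigmoid could be swapped for hard_sigmoid:

import torch
import torch.nn as nn

class ToyEca(nn.Module):
    """Minimal Efficient Channel Attention: a k-wide 1-D conv over globally
    pooled channel descriptors, followed by a gating nonlinearity."""
    def __init__(self, kernel_size=3, gate_fn=torch.sigmoid):
        super().__init__()
        self.conv = nn.Conv1d(1, 1, kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        self.gate_fn = gate_fn

    def forward(self, x):                     # x: (N, C, H, W)
        y = x.mean((2, 3))                    # global average pool -> (N, C)
        y = self.conv(y.unsqueeze(1))         # 1-D conv across channels -> (N, 1, C)
        y = self.gate_fn(y).squeeze(1)        # per-channel gates -> (N, C)
        return x * y[:, :, None, None]        # re-weight the feature map

eca = ToyEca(kernel_size=5)
print(eca(torch.randn(2, 160, 7, 7)).shape)   # shape preserved: torch.Size([2, 160, 7, 7])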
@ -382,12 +556,25 @@ def mobilenetv3_small_075(pretrained=False, **kwargs):
@register_model @register_model
def mobilenetv3_small_100(pretrained=False, **kwargs): def mobilenetv3_small_100(pretrained=False, **kwargs):
print(kwargs)
""" MobileNet V3 """ """ MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
return model return model
@register_model
def mobilenetv3_eca_large(pretrained=False, **kwargs):
""" MobileNet V3 """
model = _gen_mobilenet_v3_eca('mobilenetv3_eca_large', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def xmobilenetv3_large_100(pretrained=False, **kwargs):
""" MobileNet V3 """
model = _gen_xmobilenet_v3('xmobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model @register_model
def mobilenetv3_rw(pretrained=False, **kwargs): def mobilenetv3_rw(pretrained=False, **kwargs):
""" MobileNet V3 """ """ MobileNet V3 """

@ -11,7 +11,7 @@ import torch.nn.functional as F
from .resnet import ResNet from .resnet import ResNet
from .registry import register_model from .registry import register_model
from .helpers import load_pretrained from .helpers import load_pretrained
from .layers import SEModule from .layers import SqueezeExcite
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = [] __all__ = []
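The final hunk swaps the attention block imported by this ResNet-based model file from SEModule to SqueezeExcite. As an assumption (not verified against the rest of the branch), the practical difference is how the squeezed width is specified: SEModule-style blocks take a fixed channel divisor, while SqueezeExcite-style blocks take a ratio of the input channels. Roughly:

# Hypothetical comparison of the two conventions; names and formulas are
# assumptions, not timm's exact signatures.
channels = 256
reduction = 16                               # SEModule-style divisor
se_ratio = 0.25                              # SqueezeExcite-style ratio

print(channels // reduction)                 # 16 squeezed channels
print(int(channels * se_ratio))              # 64 squeezed channels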
