add anti-aliasing for `ir` block, mobilenet-v3

pull/603/merge^2
Rahul Somani 4 years ago
parent 9cc7dda6e5
commit c2abb2c03d

@ -218,12 +218,13 @@ class InvertedResidual(nn.Module):
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
conv_kwargs=None, drop_path_rate=0.): conv_kwargs=None, drop_path_rate=0., aa_layer=None):
super(InvertedResidual, self).__init__() super(InvertedResidual, self).__init__()
norm_kwargs = norm_kwargs or {} norm_kwargs = norm_kwargs or {}
conv_kwargs = conv_kwargs or {} conv_kwargs = conv_kwargs or {}
mid_chs = make_divisible(in_chs * exp_ratio) mid_chs = make_divisible(in_chs * exp_ratio)
has_se = se_ratio is not None and se_ratio > 0. has_se = se_ratio is not None and se_ratio > 0.
use_aa = aa_layer is not None and stride == 2
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate self.drop_path_rate = drop_path_rate
@ -234,10 +235,11 @@ class InvertedResidual(nn.Module):
# Depth-wise convolution # Depth-wise convolution
self.conv_dw = create_conv2d( self.conv_dw = create_conv2d(
mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, mid_chs, mid_chs, dw_kernel_size, stride=1 if use_aa else stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs) padding=pad_type, depthwise=True, **conv_kwargs)
self.bn2 = norm_layer(mid_chs, **norm_kwargs) self.bn2 = norm_layer(mid_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True) self.act2 = act_layer(inplace=True)
self.aa = aa_layer(mid_chs, stride=stride) if use_aa else None
# Squeeze-and-excitation # Squeeze-and-excitation
if has_se: if has_se:
@ -269,6 +271,8 @@ class InvertedResidual(nn.Module):
x = self.conv_dw(x) x = self.conv_dw(x)
x = self.bn2(x) x = self.bn2(x)
x = self.act2(x) x = self.act2(x)
if self.aa is not None:
x = self.aa(x)
# Squeeze-and-excitation # Squeeze-and-excitation
if self.se is not None: if self.se is not None:

@ -221,7 +221,7 @@ class EfficientNetBuilder:
""" """
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', act_layer=None, se_kwargs=None, output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='', norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., aa_layer=None, feature_location='',
verbose=False): verbose=False):
self.channel_multiplier = channel_multiplier self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor self.channel_divisor = channel_divisor
@ -233,6 +233,7 @@ class EfficientNetBuilder:
self.norm_layer = norm_layer self.norm_layer = norm_layer
self.norm_kwargs = norm_kwargs self.norm_kwargs = norm_kwargs
self.drop_path_rate = drop_path_rate self.drop_path_rate = drop_path_rate
self.aa_layer = aa_layer
if feature_location == 'depthwise': if feature_location == 'depthwise':
# old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
_logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'") _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
@ -269,6 +270,8 @@ class EfficientNetBuilder:
if ba.get('num_experts', 0) > 0: if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba) block = CondConvResidual(**ba)
else: else:
# FIXME: `aa_layer` only impl for `InvertedResidual`. Add `CondConvResidual`?
ba['aa_layer'] = self.aa_layer
block = InvertedResidual(**ba) block = InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa': elif bt == 'ds' or bt == 'dsa':
ba['drop_path_rate'] = drop_path_rate ba['drop_path_rate'] = drop_path_rate

@ -18,7 +18,7 @@ from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_la
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from .features import FeatureInfo, FeatureHooks from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg, default_cfg_for_features from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid from .layers import SelectAdaptivePool2d, Linear, BlurPool2d, create_conv2d, get_act_fn, hard_sigmoid
from .registry import register_model from .registry import register_model
__all__ = ['MobileNetV3'] __all__ = ['MobileNetV3']
@ -85,7 +85,7 @@ class MobileNetV3(nn.Module):
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'): se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg', aa_layer=None):
super(MobileNetV3, self).__init__() super(MobileNetV3, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
@ -101,7 +101,7 @@ class MobileNetV3(nn.Module):
# Middle stages (IR/ER/DS Blocks) # Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder( builder = EfficientNetBuilder(
channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs, channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG) norm_layer, norm_kwargs, drop_path_rate, aa_layer, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features self.feature_info = builder.features
head_chs = builder.in_chs head_chs = builder.in_chs
@ -160,7 +160,7 @@ class MobileNetV3Features(nn.Module):
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='', in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None): norm_layer=nn.BatchNorm2d, norm_kwargs=None, aa_layer=None,):
super(MobileNetV3Features, self).__init__() super(MobileNetV3Features, self).__init__()
norm_kwargs = norm_kwargs or {} norm_kwargs = norm_kwargs or {}
self.drop_rate = drop_rate self.drop_rate = drop_rate
@ -174,7 +174,7 @@ class MobileNetV3Features(nn.Module):
# Middle stages (IR/ER/DS Blocks) # Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder( builder = EfficientNetBuilder(
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs, channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) norm_layer, norm_kwargs, drop_path_rate, aa_layer, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices) self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
@ -405,6 +405,20 @@ def mobilenetv3_small_100(pretrained=False, **kwargs):
return model return model
@register_model
def mobilenetv3_large_075_aa(pretrained=False, aa_layer=BlurPool2d, **kwargs):
    """ MobileNet V3 Large 0.75, with anti-aliasing (blur-pool) applied to strided
    depthwise convs in the inverted-residual blocks.

    Args:
        pretrained (bool): load pretrained weights if available.
        aa_layer: anti-aliasing layer class passed through to the block builder
            (defaults to BlurPool2d; pass None to disable).
        **kwargs: forwarded to the MobileNetV3 constructor.
    """
    # NOTE: the channel multiplier for the *_075 variant is 0.75; the original
    # addition passed 1.0, which would have built a large_100-width network
    # under the large_075 name.
    model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, aa_layer=aa_layer, **kwargs)
    return model
@register_model
def mobilenetv3_large_100_aa(pretrained=False, aa_layer=BlurPool2d, **kwargs):
    """ MobileNet V3 Large 1.0, with anti-aliasing (blur-pool) applied to strided
    depthwise convs in the inverted-residual blocks. `aa_layer` defaults to
    BlurPool2d and is forwarded to the block builder.
    """
    return _gen_mobilenet_v3(
        'mobilenetv3_large_100', 1.0, pretrained=pretrained, aa_layer=aa_layer, **kwargs)
@register_model @register_model
def mobilenetv3_rw(pretrained=False, **kwargs): def mobilenetv3_rw(pretrained=False, **kwargs):
""" MobileNet V3 """ """ MobileNet V3 """

Loading…
Cancel
Save