From c2abb2c03d15a719ca1f803032d530bff801a2c6 Mon Sep 17 00:00:00 2001
From: Rahul Somani
Date: Sat, 1 May 2021 11:50:43 +0530
Subject: [PATCH 1/4] add anti-aliasing for `ir` block, mobnet-v3

---
 timm/models/efficientnet_blocks.py  |  8 ++++++--
 timm/models/efficientnet_builder.py |  5 ++++-
 timm/models/mobilenetv3.py          | 24 +++++++++++++++++++-----
 3 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py
index 114533cf..e81bb75b 100644
--- a/timm/models/efficientnet_blocks.py
+++ b/timm/models/efficientnet_blocks.py
@@ -218,12 +218,13 @@ class InvertedResidual(nn.Module):
                  stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                  exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0., se_kwargs=None,
                  norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 conv_kwargs=None, drop_path_rate=0.):
+                 conv_kwargs=None, drop_path_rate=0., aa_layer=None):
         super(InvertedResidual, self).__init__()
         norm_kwargs = norm_kwargs or {}
         conv_kwargs = conv_kwargs or {}
         mid_chs = make_divisible(in_chs * exp_ratio)
         has_se = se_ratio is not None and se_ratio > 0.
+        use_aa = aa_layer is not None and stride == 2
         self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
         self.drop_path_rate = drop_path_rate
 
@@ -234,10 +235,11 @@ class InvertedResidual(nn.Module):
 
         # Depth-wise convolution
         self.conv_dw = create_conv2d(
-            mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
+            mid_chs, mid_chs, dw_kernel_size, stride=1 if use_aa else stride, dilation=dilation,
             padding=pad_type, depthwise=True, **conv_kwargs)
         self.bn2 = norm_layer(mid_chs, **norm_kwargs)
         self.act2 = act_layer(inplace=True)
+        self.aa = aa_layer(mid_chs, stride=stride) if use_aa else None
 
         # Squeeze-and-excitation
         if has_se:
@@ -269,6 +271,8 @@ class InvertedResidual(nn.Module):
         x = self.conv_dw(x)
         x = self.bn2(x)
         x = self.act2(x)
+        if self.aa is not None:
+            x = self.aa(x)
 
         # Squeeze-and-excitation
         if self.se is not None:
diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py
index f670aa6c..bb8ce796 100644
--- a/timm/models/efficientnet_builder.py
+++ b/timm/models/efficientnet_builder.py
@@ -221,7 +221,7 @@ class EfficientNetBuilder:
     """
     def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='',
+                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., aa_layer=None, feature_location='',
                  verbose=False):
         self.channel_multiplier = channel_multiplier
         self.channel_divisor = channel_divisor
@@ -233,6 +233,7 @@ class EfficientNetBuilder:
         self.norm_layer = norm_layer
         self.norm_kwargs = norm_kwargs
         self.drop_path_rate = drop_path_rate
+        self.aa_layer = aa_layer
         if feature_location == 'depthwise':
             # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
             _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
@@ -269,6 +270,8 @@ class EfficientNetBuilder:
             if ba.get('num_experts', 0) > 0:
                 block = CondConvResidual(**ba)
             else:
+                # FIXME: `aa_layer` only impl for `InvertedResidual`. Add `CondConvResidual`?
+                ba['aa_layer'] = self.aa_layer
                 block = InvertedResidual(**ba)
         elif bt == 'ds' or bt == 'dsa':
             ba['drop_path_rate'] = drop_path_rate
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index 543b33ea..ce59c5b5 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -18,7 +18,7 @@ from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_la
 from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
 from .features import FeatureInfo, FeatureHooks
 from .helpers import build_model_with_cfg, default_cfg_for_features
-from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
+from .layers import SelectAdaptivePool2d, Linear, BlurPool2d, create_conv2d, get_act_fn, hard_sigmoid
 from .registry import register_model
 
 __all__ = ['MobileNetV3']
@@ -85,7 +85,7 @@ class MobileNetV3(nn.Module):
     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280,
                  head_bias=True, channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU,
                  drop_rate=0., drop_path_rate=0.,
-                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
+                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg', aa_layer=None):
         super(MobileNetV3, self).__init__()
 
         self.num_classes = num_classes
@@ -101,7 +101,7 @@ class MobileNetV3(nn.Module):
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
             channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
-            norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
+            norm_layer, norm_kwargs, drop_path_rate, aa_layer, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
         head_chs = builder.in_chs
@@ -160,7 +160,7 @@ class MobileNetV3Features(nn.Module):
     def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
                  in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
                  act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
-                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
+                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, aa_layer=None):
         super(MobileNetV3Features, self).__init__()
         norm_kwargs = norm_kwargs or {}
         self.drop_rate = drop_rate
@@ -174,7 +174,7 @@ class MobileNetV3Features(nn.Module):
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
             channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
-            norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
+            norm_layer, norm_kwargs, drop_path_rate, aa_layer, feature_location=feature_location, verbose=_DEBUG)
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = FeatureInfo(builder.features, out_indices)
         self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
@@ -405,6 +405,20 @@ def mobilenetv3_small_100(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def mobilenetv3_large_075_aa(pretrained=False, aa_layer=BlurPool2d, **kwargs):
+    """ MobileNet V3 """
+    model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, aa_layer=aa_layer, **kwargs)
+    return model
+
+
+@register_model
+def mobilenetv3_large_100_aa(pretrained=False, aa_layer=BlurPool2d, **kwargs):
+    """ MobileNet V3 """
+    model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, aa_layer=aa_layer, **kwargs)
+    return model
+
+
 @register_model
 def mobilenetv3_rw(pretrained=False, **kwargs):
     """ MobileNet V3 """

From e060223cee7ddd5a6e9c88904e1bce49c2842732 Mon Sep 17 00:00:00 2001
From: Rahul Somani
Date: Sat, 1 May 2021 12:10:49 +0530
Subject: [PATCH 2/4] add default configs for aa models

---
 timm/models/mobilenetv3.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index ce59c5b5..d9b4d755 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -39,6 +39,8 @@ default_cfgs = {
     'mobilenetv3_large_100': _cfg(
         interpolation='bicubic',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
+    'mobilenetv3_large_075_aa': _cfg(url=''),
+    'mobilenetv3_large_100_aa': _cfg(url=''),
     'mobilenetv3_large_100_miil': _cfg(
         interpolation='bilinear', mean=(0, 0, 0), std=(1, 1, 1),
         url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_1k_miil_78_0.pth'),
@@ -408,12 +410,12 @@ def mobilenetv3_small_100(pretrained=False, **kwargs):
 @register_model
 def mobilenetv3_large_075_aa(pretrained=False, aa_layer=BlurPool2d, **kwargs):
     """ MobileNet V3 """
-    model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, aa_layer=aa_layer, **kwargs)
+    model = _gen_mobilenet_v3('mobilenetv3_large_075_aa', 0.75, pretrained=pretrained, aa_layer=aa_layer, **kwargs)
     return model
 
 
 @register_model
 def mobilenetv3_large_100_aa(pretrained=False, aa_layer=BlurPool2d, **kwargs):
     """ MobileNet V3 """
-    model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, aa_layer=aa_layer, **kwargs)
+    model = _gen_mobilenet_v3('mobilenetv3_large_100_aa', 1.0, pretrained=pretrained, aa_layer=aa_layer, **kwargs)
     return model

From c4bb8c7d4a8a9873968bb9d9cc6506d07dc32da3 Mon Sep 17 00:00:00 2001
From: Rahul Somani
Date: Mon, 10 May 2021 13:41:04 +0530
Subject: [PATCH 3/4] add model url

---
 timm/models/mobilenetv3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index d9b4d755..026ce363 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -40,7 +40,7 @@ default_cfgs = {
         interpolation='bicubic',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
     'mobilenetv3_large_075_aa': _cfg(url=''),
-    'mobilenetv3_large_100_aa': _cfg(url=''),
+    'mobilenetv3_large_100_aa': _cfg(url='https://storage.googleapis.com/cinemanet-models/pretrained/mobilenetv3_large_100_aa_224x224_ema.pth'),
     'mobilenetv3_large_100_miil': _cfg(
         interpolation='bilinear', mean=(0, 0, 0), std=(1, 1, 1),
         url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_1k_miil_78_0.pth'),

From f8ce54893c16967306dfa94a4f0e29d9c4fff0ad Mon Sep 17 00:00:00 2001
From: Rahul Somani
Date: Tue, 11 May 2021 23:43:23 +0530
Subject: [PATCH 4/4] add aa stem logic, minimal docstrings

---
 timm/models/mobilenetv3.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index 026ce363..dcccf4ce 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -41,6 +41,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
     'mobilenetv3_large_075_aa': _cfg(url=''),
     'mobilenetv3_large_100_aa': _cfg(url='https://storage.googleapis.com/cinemanet-models/pretrained/mobilenetv3_large_100_aa_224x224_ema.pth'),
+    'mobilenetv3_large_100_aa_stem': _cfg(url=''),
     'mobilenetv3_large_100_miil': _cfg(
         interpolation='bilinear', mean=(0, 0, 0), std=(1, 1, 1),
         url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_1k_miil_78_0.pth'),
@@ -87,7 +88,7 @@ class MobileNetV3(nn.Module):
     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280,
                  head_bias=True, channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU,
                  drop_rate=0., drop_path_rate=0.,
-                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg', aa_layer=None):
+                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg', aa_layer=None, aa_stem=None):
         super(MobileNetV3, self).__init__()
 
         self.num_classes = num_classes
@@ -97,6 +98,8 @@ class MobileNetV3(nn.Module):
         # Stem
         stem_size = round_channels(stem_size, channel_multiplier)
-        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
+        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=1 if aa_stem else 2, padding=pad_type)
+        # blur pool takes over the stem's stride-2 downsample; it acts on stem_size channels, not in_chans
+        self.conv_stem_aa = aa_stem(stem_size, stride=2) if aa_stem else None
         self.bn1 = norm_layer(stem_size, **norm_kwargs)
         self.act1 = act_layer(inplace=True)
 
@@ -135,6 +138,8 @@ def forward_features(self, x):
         x = self.conv_stem(x)
+        if self.conv_stem_aa is not None:
+            x = self.conv_stem_aa(x)
         x = self.bn1(x)
         x = self.act1(x)
         x = self.blocks(x)
@@ -416,10 +421,18 @@ def mobilenetv3_large_075_aa(pretrained=False, aa_layer=BlurPool2d, **kwargs):
 @register_model
 def mobilenetv3_large_100_aa(pretrained=False, aa_layer=BlurPool2d, **kwargs):
-    """ MobileNet V3 """
+    """ MobileNet V3 w/ Blur Pooling of IR Blocks """
     model = _gen_mobilenet_v3('mobilenetv3_large_100_aa', 1.0, pretrained=pretrained, aa_layer=aa_layer, **kwargs)
     return model
 
 
+@register_model
+def mobilenetv3_large_100_aa_stem(pretrained=False, aa_layer=BlurPool2d, aa_stem=BlurPool2d, **kwargs):
+    """ MobileNet V3 w/ Blur Pooling of IR Blocks & Conv Stem """
+    model = _gen_mobilenet_v3('mobilenetv3_large_100_aa_stem', 1.0, pretrained=pretrained,
+                              aa_layer=aa_layer, aa_stem=aa_stem, **kwargs)
+    return model
+
+
 @register_model
 def mobilenetv3_rw(pretrained=False, **kwargs):
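
Note for reviewers: `BlurPool2d` is timm's anti-aliased downsampling layer from Zhang, "Making Convolutional Networks Shift-Invariant Again" (ICML 2019): a fixed binomial low-pass filter applied per channel, followed by subsampling. The sketch below is illustrative only (it assumes timm's `(channels, filt_size, stride)` call signature and is not the shipped implementation). It also shows why the stem's blur layer must be built with `stem_size` channels rather than `in_chans`: the filter is registered once per input channel and applied as a grouped convolution.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from math import comb

    class BlurPool2dSketch(nn.Module):
        """Fixed low-pass filter + subsample (Zhang, 2019); illustrative stand-in."""
        def __init__(self, channels, filt_size=3, stride=2):
            super().__init__()
            self.channels = channels
            self.stride = stride
            self.pad = (filt_size - 1) // 2
            # binomial kernel, e.g. [1, 2, 1] for filt_size=3
            a = torch.tensor([comb(filt_size - 1, i) for i in range(filt_size)], dtype=torch.float32)
            filt = a[:, None] * a[None, :]
            # one copy of the kernel per channel -> depthwise (grouped) conv;
            # `channels` must therefore match the incoming tensor's channel count
            self.register_buffer('filt', (filt / filt.sum())[None, None].repeat(channels, 1, 1, 1))

        def forward(self, x):
            x = F.pad(x, [self.pad] * 4, mode='reflect')
            return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)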
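
A quick smoke test for the series as a whole (assumes these patches are applied to a timm checkout so the registered names resolve; only the public `timm.create_model` API is used). Because blur pooling takes over the existing stride-2 downsamples instead of adding extra ones, each variant should produce the same classifier output shape as the baseline:

    import torch
    import timm

    for name in ('mobilenetv3_large_100', 'mobilenetv3_large_100_aa', 'mobilenetv3_large_100_aa_stem'):
        model = timm.create_model(name, pretrained=False).eval()
        with torch.no_grad():
            out = model(torch.randn(1, 3, 224, 224))
        print(name, tuple(out.shape))  # expected: (1, 1000) for all three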