From c4f482a08b369e00a27bbfd266781c98fde45016 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 14 May 2021 15:50:00 -0700 Subject: [PATCH] EfficientNetV2 official impl w/ weights ported from TF. Cleanup/refactor of related EfficientNet classes and models. --- timm/models/efficientnet.py | 463 +++++++++++++++++++++++----- timm/models/efficientnet_blocks.py | 232 +++++--------- timm/models/efficientnet_builder.py | 110 ++++--- timm/models/ghostnet.py | 5 +- timm/models/hardcorenas.py | 22 +- timm/models/layers/helpers.py | 6 +- timm/models/mobilenetv3.py | 67 ++-- 7 files changed, 587 insertions(+), 318 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 0c414d50..a64adde6 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -2,6 +2,9 @@ An implementation of EfficienNet that covers variety of related models with efficient architectures: +* EfficientNet-V2 + - `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + * EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent weight ports) - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971 @@ -22,23 +25,26 @@ An implementation of EfficienNet that covers variety of related models with effi * And likely more... -Hacked together by / Copyright 2020 Ross Wightman +Hacked together by / Copyright 2021 Ross Wightman """ +from functools import partial +from typing import List + import torch import torch.nn as nn import torch.nn.functional as F -from typing import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights +from .efficientnet_blocks import SqueezeExcite +from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ + round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .features import FeatureInfo, FeatureHooks from .helpers import build_model_with_cfg, default_cfg_for_features from .layers import create_conv2d, create_classifier from .registry import register_model -__all__ = ['EfficientNet'] +__all__ = ['EfficientNet', 'EfficientNetFeatures'] def _cfg(url='', **kwargs): @@ -149,9 +155,20 @@ default_cfgs = { url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), - 'efficientnet_v2s': _cfg( + 'efficientnetv2_rw_s': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth', - input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), # FIXME WIP + input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), + + 'efficientnetv2_s': _cfg( + url='', + input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0), + 'efficientnetv2_m': _cfg( + url='', + input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0), + 'efficientnetv2_l': _cfg( + url='', + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'tf_efficientnet_b0': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', @@ -298,6 +315,58 @@ default_cfgs = { mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'), + 'tf_efficientnetv2_s': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'tf_efficientnetv2_l': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnetv2_s_21kft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21kft1k-d7dafa41.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m_21kft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21kft1k-bf41664a.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'tf_efficientnetv2_l_21kft1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21kft1k-60127a9d.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnetv2_s_21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0), + 'tf_efficientnetv2_m_21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + 'tf_efficientnetv2_l_21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + + 'tf_efficientnetv2_b0': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth', + input_size=(3, 192, 192), test_input_size=(3, 224, 224), pool_size=(6, 6)), + 'tf_efficientnetv2_b1': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b1-be6e41b0.pth', + input_size=(3, 192, 192), test_input_size=(3, 240, 240), pool_size=(6, 6), crop_pct=0.882), + 'tf_efficientnetv2_b2': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b2-847de54e.pth', + input_size=(3, 208, 208), test_input_size=(3, 260, 260), pool_size=(7, 7), crop_pct=0.890), + 'tf_efficientnetv2_b3': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b3-57773f13.pth', + input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904), + 'mixnet_s': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'), 'mixnet_m': _cfg( @@ -316,13 +385,12 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'), } -_DEBUG = False - class EfficientNet(nn.Module): """ (Generic) EfficientNet A flexible and performant PyTorch implementation of efficient network architectures, including: + * EfficientNet-V2 Small, Medium, Large & B0-B3 * EfficientNet B0-B8, L2 * EfficientNet-EdgeTPU * EfficientNet-CondConv @@ -333,35 +401,35 @@ class EfficientNet(nn.Module): """ - def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, - channel_multiplier=1.0, channel_divisor=8, channel_min=None, - output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., - se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'): + def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, fix_stem=False, + output_stride=32, pad_type='', round_chs_fn=round_channels, act_layer=None, norm_layer=None, + se_layer=None, drop_rate=0., drop_path_rate=0., global_pool='avg'): super(EfficientNet, self).__init__() - norm_kwargs = norm_kwargs or {} - + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite self.num_classes = num_classes self.num_features = num_features self.drop_rate = drop_rate # Stem if not fix_stem: - stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + stem_size = round_chs_fn(stem_size) self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) - self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.bn1 = norm_layer(stem_size) self.act1 = act_layer(inplace=True) # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( - channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs, - norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG) + output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, + act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features head_chs = builder.in_chs # Head + Pooling self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type) - self.bn2 = norm_layer(self.num_features, **norm_kwargs) + self.bn2 = norm_layer(self.num_features) self.act2 = act_layer(inplace=True) self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) @@ -408,25 +476,27 @@ class EfficientNetFeatures(nn.Module): and object detection models. """ - def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', - in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - output_stride=32, pad_type='', fix_stem=False, act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., - se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None): + def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, + stem_size=32, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels, + act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): super(EfficientNetFeatures, self).__init__() - norm_kwargs = norm_kwargs or {} + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite self.drop_rate = drop_rate # Stem if not fix_stem: - stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + stem_size = round_chs_fn(stem_size) self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) - self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.bn1 = norm_layer(stem_size) self.act1 = act_layer(inplace=True) # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( - channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs, - norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) + output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, + act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate, + feature_location=feature_location) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = FeatureInfo(builder.features, out_indices) self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} @@ -505,8 +575,8 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs) model_kwargs = dict( block_args=decode_arch_def(arch_def), stem_size=32, - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -541,8 +611,8 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs) model_kwargs = dict( block_args=decode_arch_def(arch_def), stem_size=32, - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -570,8 +640,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar model_kwargs = dict( block_args=decode_arch_def(arch_def), stem_size=8, - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -593,13 +663,14 @@ def _gen_mobilenet_v2( ['ir_r3_k3_s2_e6_c160'], ['ir_r1_k3_s1_e6_c320'], ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) model_kwargs = dict( block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head), - num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None), + num_features=1280 if fix_stem_head else round_chs_fn(1280), stem_size=32, fix_stem=fix_stem_head, - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), + round_chs_fn=round_chs_fn, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'relu6'), **kwargs ) @@ -629,8 +700,8 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): block_args=decode_arch_def(arch_def), stem_size=16, num_features=1984, # paper suggests this, but is not 100% clear - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -664,8 +735,8 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): model_kwargs = dict( block_args=decode_arch_def(arch_def), stem_size=32, - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -705,13 +776,14 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre ['ir_r4_k5_s2_e6_c192_se0.25'], ['ir_r1_k3_s1_e6_c320_se0.25'], ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) model_kwargs = dict( block_args=decode_arch_def(arch_def, depth_multiplier), - num_features=round_channels(1280, channel_multiplier, 8, None), + num_features=round_chs_fn(1280), stem_size=32, - channel_multiplier=channel_multiplier, + round_chs_fn=round_chs_fn, act_layer=resolve_act_layer(kwargs, 'swish'), - norm_kwargs=resolve_bn_args(kwargs), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs, ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -734,12 +806,13 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0 ['ir_r4_k5_s1_e8_c144'], ['ir_r2_k5_s2_e8_c192'], ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) model_kwargs = dict( block_args=decode_arch_def(arch_def, depth_multiplier), - num_features=round_channels(1280, channel_multiplier, 8, None), + num_features=round_chs_fn(1280), stem_size=32, - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), + round_chs_fn=round_chs_fn, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'relu'), **kwargs, ) @@ -764,12 +837,13 @@ def _gen_efficientnet_condconv( ] # NOTE unlike official impl, this one uses `cc` option where x is the base number of experts for each stage and # the expert_multiplier increases that on a per-model basis as with depth/channel multipliers + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) model_kwargs = dict( block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), - num_features=round_channels(1280, channel_multiplier, 8, None), + num_features=round_chs_fn(1280), stem_size=32, - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), + round_chs_fn=round_chs_fn, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'swish'), **kwargs, ) @@ -809,45 +883,137 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0 num_features=1280, stem_size=32, fix_stem=True, - channel_multiplier=channel_multiplier, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), act_layer=resolve_act_layer(kwargs, 'relu6'), - norm_kwargs=resolve_bn_args(kwargs), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs, ) model = _create_effnet(variant, pretrained, **model_kwargs) return model -def _gen_efficientnet_v2s(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): - """ Creates an EfficientNet-V2s model - - NOTE: this is a preliminary definition based on paper, awaiting official code release for details - and weights +def _gen_efficientnetv2_base( + variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 base model - Ref impl: + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 """ + arch_def = [ + ['cn_r1_k3_s1_e1_c16_skip'], + ['er_r2_k3_s2_e4_c32'], + ['er_r2_k3_s2_e4_c48'], + ['ir_r3_k3_s2_e4_c96_se0.25'], + ['ir_r5_k3_s1_e6_c112_se0.25'], + ['ir_r8_k3_s2_e6_c192_se0.25'], + ] + round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_chs_fn(1280), + stem_size=32, + round_chs_fn=round_chs_fn, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + +def _gen_efficientnetv2_s( + variant, channel_multiplier=1.0, depth_multiplier=1.0, rw=False, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Small model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + + NOTE: `rw` flag sets up 'small' variant to behave like my initial v2 small model, + before ref the impl was released. + """ arch_def = [ - # FIXME it's not clear if the FusedMBConv layers have SE enabled for the Small variant, - # Table 4 suggests no. 23.94M params w/o, 23.98 with which is closer to 24M. - # ['er_r2_k3_s1_e1_c24_se0.25'], - # ['er_r4_k3_s2_e4_c48_se0.25'], - # ['er_r4_k3_s2_e4_c64_se0.25'], - ['er_r2_k3_s1_e1_c24'], + ['cn_r2_k3_s1_e1_c24_skip'], ['er_r4_k3_s2_e4_c48'], ['er_r4_k3_s2_e4_c64'], ['ir_r6_k3_s2_e4_c128_se0.25'], ['ir_r9_k3_s1_e6_c160_se0.25'], - ['ir_r15_k3_s2_e6_c272_se0.25'], + ['ir_r15_k3_s2_e6_c256_se0.25'], + ] + num_features = 1280 + if rw: + # my original variant, based on paper figure differs from the official release + arch_def[0] = ['er_r2_k3_s1_e1_c24'] + arch_def[-1] = ['ir_r15_k3_s2_e6_c272_se0.25'] + num_features = 1792 + + round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_chs_fn(num_features), + stem_size=24, + round_chs_fn=round_chs_fn, + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Medium model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + + arch_def = [ + ['cn_r3_k3_s1_e1_c24_skip'], + ['er_r5_k3_s2_e4_c48'], + ['er_r5_k3_s2_e4_c80'], + ['ir_r7_k3_s2_e4_c160_se0.25'], + ['ir_r14_k3_s1_e6_c176_se0.25'], + ['ir_r18_k3_s2_e6_c304_se0.25'], + ['ir_r5_k3_s1_e6_c512_se0.25'], ] + model_kwargs = dict( block_args=decode_arch_def(arch_def, depth_multiplier), - num_features=round_channels(1792, channel_multiplier, 8, None), + num_features=1280, stem_size=24, - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), - act_layer=resolve_act_layer(kwargs, 'silu'), # FIXME this is an assumption, paper does not mention + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), + **kwargs, + ) + model = _create_effnet(variant, pretrained, **model_kwargs) + return model + + +def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """ Creates an EfficientNet-V2 Large model + + Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 + Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298 + """ + + arch_def = [ + ['cn_r4_k3_s1_e1_c32_skip'], + ['er_r7_k3_s2_e4_c64'], + ['er_r7_k3_s2_e4_c96'], + ['ir_r10_k3_s2_e4_c192_se0.25'], + ['ir_r19_k3_s1_e6_c224_se0.25'], + ['ir_r25_k3_s2_e6_c384_se0.25'], + ['ir_r7_k3_s1_e6_c640_se0.25'], + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=1280, + stem_size=32, + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + act_layer=resolve_act_layer(kwargs, 'silu'), **kwargs, ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -879,8 +1045,8 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): block_args=decode_arch_def(arch_def), num_features=1536, stem_size=16, - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -912,8 +1078,8 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), num_features=1536, stem_size=24, - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -1290,13 +1456,35 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs): @register_model -def efficientnet_v2s(pretrained=False, **kwargs): +def efficientnetv2_rw_s(pretrained=False, **kwargs): + """ EfficientNet-V2 Small. + NOTE: This is my initial (pre official code release) w/ some differences. + See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official w/ PyTorch vs TF padding + """ + model = _gen_efficientnetv2_s('efficientnetv2_rw_s', rw=True, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_s(pretrained=False, **kwargs): """ EfficientNet-V2 Small. """ - model = _gen_efficientnet_v2s( - 'efficientnet_v2s', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + model = _gen_efficientnetv2_s('efficientnetv2_s', pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnetv2_m(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium. """ + model = _gen_efficientnetv2_m('efficientnetv2_m', pretrained=pretrained, **kwargs) return model +@register_model +def efficientnetv2_l(pretrained=False, **kwargs): + """ EfficientNet-V2 Large. """ + model = _gen_efficientnetv2_l('efficientnetv2_l', pretrained=pretrained, **kwargs) + return model + @register_model def tf_efficientnet_b0(pretrained=False, **kwargs): @@ -1708,6 +1896,127 @@ def tf_efficientnet_lite4(pretrained=False, **kwargs): return model + +@register_model +def tf_efficientnetv2_s(pretrained=False, **kwargs): + """ EfficientNet-V2 Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_s('tf_efficientnetv2_s', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_m(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_m('tf_efficientnetv2_m', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_l(pretrained=False, **kwargs): + """ EfficientNet-V2 Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_l('tf_efficientnetv2_l', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_s_21kft1k(pretrained=False, **kwargs): + """ EfficientNet-V2 Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_s('tf_efficientnetv2_s_21kft1k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_m_21kft1k(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_m('tf_efficientnetv2_m_21kft1k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_l_21kft1k(pretrained=False, **kwargs): + """ EfficientNet-V2 Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_l('tf_efficientnetv2_l_21kft1k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_s_21k(pretrained=False, **kwargs): + """ EfficientNet-V2 Small w/ ImageNet-21k pretrained weights. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_s('tf_efficientnetv2_s_21k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_m_21k(pretrained=False, **kwargs): + """ EfficientNet-V2 Medium w/ ImageNet-21k pretrained weights. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_m('tf_efficientnetv2_m_21k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_l_21k(pretrained=False, **kwargs): + """ EfficientNet-V2 Large w/ ImageNet-21k pretrained weights. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_l('tf_efficientnetv2_l_21k', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b0(pretrained=False, **kwargs): + """ EfficientNet-V2-B0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_base('tf_efficientnetv2_b0', pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b1(pretrained=False, **kwargs): + """ EfficientNet-V2-B1. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_base( + 'tf_efficientnetv2_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b2(pretrained=False, **kwargs): + """ EfficientNet-V2-B2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_base( + 'tf_efficientnetv2_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +@register_model +def tf_efficientnetv2_b3(pretrained=False, **kwargs): + """ EfficientNet-V2-B3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnetv2_base( + 'tf_efficientnetv2_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + @register_model def mixnet_s(pretrained=False, **kwargs): """Creates a MixNet Small model. diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index 040785f6..83b57beb 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -7,106 +7,34 @@ import torch import torch.nn as nn from torch.nn import functional as F -from .layers import create_conv2d, drop_path, get_act_layer +from .layers import create_conv2d, drop_path, make_divisible from .layers.activations import sigmoid -# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per -# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay) -# NOTE: momentum varies btw .99 and .9997 depending on source -# .99 in official TF TPU impl -# .9997 (/w .999 in search space) for paper -BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 -BN_EPS_TF_DEFAULT = 1e-3 -_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) - - -def get_bn_args_tf(): - return _BN_ARGS_TF.copy() - - -def resolve_bn_args(kwargs): - bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} - bn_momentum = kwargs.pop('bn_momentum', None) - if bn_momentum is not None: - bn_args['momentum'] = bn_momentum - bn_eps = kwargs.pop('bn_eps', None) - if bn_eps is not None: - bn_args['eps'] = bn_eps - return bn_args - - -_SE_ARGS_DEFAULT = dict( - gate_fn=sigmoid, - act_layer=None, - reduce_mid=False, - divisor=1) - - -def resolve_se_args(kwargs, in_chs, act_layer=None): - se_kwargs = kwargs.copy() if kwargs is not None else {} - # fill in args that aren't specified with the defaults - for k, v in _SE_ARGS_DEFAULT.items(): - se_kwargs.setdefault(k, v) - # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch - if not se_kwargs.pop('reduce_mid'): - se_kwargs['reduced_base_chs'] = in_chs - # act_layer override, if it remains None, the containing block's act_layer will be used - if se_kwargs['act_layer'] is None: - assert act_layer is not None - se_kwargs['act_layer'] = act_layer - return se_kwargs - - -def resolve_act_layer(kwargs, default='relu'): - act_layer = kwargs.pop('act_layer', default) - if isinstance(act_layer, str): - act_layer = get_act_layer(act_layer) - return act_layer - - -def make_divisible(v, divisor=8, min_value=None): - min_value = min_value or divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * v: - new_v += divisor - return new_v - - -def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): - """Round number of filters based on depth multiplier.""" - if not multiplier: - return channels - channels *= multiplier - return make_divisible(channels, divisor, channel_min) - - -class ChannelShuffle(nn.Module): - # FIXME haven't used yet - def __init__(self, groups): - super(ChannelShuffle, self).__init__() - self.groups = groups - - def forward(self, x): - """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]""" - N, C, H, W = x.size() - g = self.groups - assert C % g == 0, "Incompatible group size {} for input channel {}".format( - g, C - ) - return ( - x.view(N, g, int(C / g), H, W) - .permute(0, 2, 1, 3, 4) - .contiguous() - .view(N, C, H, W) - ) +__all__ = [ + 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual'] class SqueezeExcite(nn.Module): - def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, - act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_): + """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family + + Args: + in_chs (int): input channels to layer + se_ratio (float): ratio of squeeze reduction + act_layer (nn.Module): activation layer of containing block + gate_fn (Callable): attention gate function + block_in_chs (int): input channels of containing block (for calculating reduction from) + reduce_from_block (bool): calculate reduction from block input channels if True + force_act_layer (nn.Module): override block's activation fn if this is set/bound + divisor (int): make reduction channels divisible by this + """ + + def __init__( + self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid, + block_in_chs=None, reduce_from_block=True, force_act_layer=None, divisor=1): super(SqueezeExcite, self).__init__() - reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + reduced_chs = (block_in_chs or in_chs) if reduce_from_block else in_chs + reduced_chs = make_divisible(reduced_chs * se_ratio, divisor) + act_layer = force_act_layer or act_layer self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) self.act1 = act_layer(inplace=True) self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) @@ -121,13 +49,16 @@ class SqueezeExcite(nn.Module): class ConvBnAct(nn.Module): - def __init__(self, in_chs, out_chs, kernel_size, - stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, - norm_layer=nn.BatchNorm2d, norm_kwargs=None): + """ Conv + Norm Layer + Activation w/ optional skip connection + """ + def __init__( + self, in_chs, out_chs, kernel_size, stride=1, dilation=1, pad_type='', + skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.): super(ConvBnAct, self).__init__() - norm_kwargs = norm_kwargs or {} + self.has_residual = skip and stride == 1 and in_chs == out_chs + self.drop_path_rate = drop_path_rate self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type) - self.bn1 = norm_layer(out_chs, **norm_kwargs) + self.bn1 = norm_layer(out_chs) self.act1 = act_layer(inplace=True) def feature_info(self, location): @@ -138,9 +69,14 @@ class ConvBnAct(nn.Module): return info def forward(self, x): + shortcut = x x = self.conv(x) x = self.bn1(x) x = self.act1(x) + if self.has_residual: + if self.drop_path_rate > 0.: + x = drop_path(x, self.drop_path_rate, self.training) + x += shortcut return x @@ -149,31 +85,26 @@ class DepthwiseSeparableConv(nn.Module): Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion (factor of 1.0). This is an alternative to having a IR with an optional first pw conv. """ - def __init__(self, in_chs, out_chs, dw_kernel_size=3, - stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, - pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None, - norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.): + def __init__( + self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='', + noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0., + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.): super(DepthwiseSeparableConv, self).__init__() - norm_kwargs = norm_kwargs or {} - has_se = se_ratio is not None and se_ratio > 0. + has_se = se_layer is not None and se_ratio > 0. self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip self.has_pw_act = pw_act # activation after point-wise conv self.drop_path_rate = drop_path_rate self.conv_dw = create_conv2d( in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True) - self.bn1 = norm_layer(in_chs, **norm_kwargs) + self.bn1 = norm_layer(in_chs) self.act1 = act_layer(inplace=True) # Squeeze-and-excitation - if has_se: - se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) - self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) - else: - self.se = None + self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity() self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) - self.bn2 = norm_layer(out_chs, **norm_kwargs) + self.bn2 = norm_layer(out_chs) self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() def feature_info(self, location): @@ -190,8 +121,7 @@ class DepthwiseSeparableConv(nn.Module): x = self.bn1(x) x = self.act1(x) - if self.se is not None: - x = self.se(x) + x = self.se(x) x = self.conv_pw(x) x = self.bn2(x) @@ -214,41 +144,36 @@ class InvertedResidual(nn.Module): * MobileNet-V3 - https://arxiv.org/abs/1905.02244 """ - def __init__(self, in_chs, out_chs, dw_kernel_size=3, - stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, - exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, - se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, - conv_kwargs=None, drop_path_rate=0.): + def __init__( + self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='', + noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0., + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.): super(InvertedResidual, self).__init__() - norm_kwargs = norm_kwargs or {} conv_kwargs = conv_kwargs or {} mid_chs = make_divisible(in_chs * exp_ratio) - has_se = se_ratio is not None and se_ratio > 0. + has_se = se_layer is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_path_rate = drop_path_rate # Point-wise expansion self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) - self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.bn1 = norm_layer(mid_chs) self.act1 = act_layer(inplace=True) # Depth-wise convolution self.conv_dw = create_conv2d( mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True, **conv_kwargs) - self.bn2 = norm_layer(mid_chs, **norm_kwargs) + self.bn2 = norm_layer(mid_chs) self.act2 = act_layer(inplace=True) # Squeeze-and-excitation - if has_se: - se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) - self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) - else: - self.se = None + self.se = se_layer( + mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity() # Point-wise linear projection self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) - self.bn3 = norm_layer(out_chs, **norm_kwargs) + self.bn3 = norm_layer(out_chs) def feature_info(self, location): if location == 'expansion': # after SE, input to PWL @@ -271,8 +196,7 @@ class InvertedResidual(nn.Module): x = self.act2(x) # Squeeze-and-excitation - if self.se is not None: - x = self.se(x) + x = self.se(x) # Point-wise linear projection x = self.conv_pwl(x) @@ -289,11 +213,10 @@ class InvertedResidual(nn.Module): class CondConvResidual(InvertedResidual): """ Inverted residual block w/ CondConv routing""" - def __init__(self, in_chs, out_chs, dw_kernel_size=3, - stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, - exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, - se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, - num_experts=0, drop_path_rate=0.): + def __init__( + self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='', + noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0., + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.): self.num_experts = num_experts conv_kwargs = dict(num_experts=self.num_experts) @@ -301,9 +224,8 @@ class CondConvResidual(InvertedResidual): super(CondConvResidual, self).__init__( in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type, act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, - pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs, - norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs, - drop_path_rate=drop_path_rate) + pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer, + norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate) self.routing_fn = nn.Linear(in_chs, self.num_experts) @@ -325,8 +247,7 @@ class CondConvResidual(InvertedResidual): x = self.act2(x) # Squeeze-and-excitation - if self.se is not None: - x = self.se(x) + x = self.se(x) # Point-wise linear projection x = self.conv_pwl(x, routing_weights) @@ -351,36 +272,32 @@ class EdgeResidual(nn.Module): * EfficientNet-V2 - https://arxiv.org/abs/2104.00298 """ - def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, - stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, - se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, - drop_path_rate=0.): + def __init__( + self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='', + force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, se_ratio=0., + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.): super(EdgeResidual, self).__init__() - norm_kwargs = norm_kwargs or {} - if fake_in_chs > 0: - mid_chs = make_divisible(fake_in_chs * exp_ratio) + if force_in_chs > 0: + mid_chs = make_divisible(force_in_chs * exp_ratio) else: mid_chs = make_divisible(in_chs * exp_ratio) - has_se = se_ratio is not None and se_ratio > 0. + has_se = se_layer is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_path_rate = drop_path_rate # Expansion convolution self.conv_exp = create_conv2d( in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, padding=pad_type) - self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.bn1 = norm_layer(mid_chs) self.act1 = act_layer(inplace=True) # Squeeze-and-excitation - if has_se: - se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) - self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) - else: - self.se = None + self.se = SqueezeExcite( + mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity() # Point-wise linear projection self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type) - self.bn2 = norm_layer(out_chs, **norm_kwargs) + self.bn2 = norm_layer(out_chs) def feature_info(self, location): if location == 'expansion': # after SE, before PWL @@ -398,8 +315,7 @@ class EdgeResidual(nn.Module): x = self.act1(x) # Squeeze-and-excitation - if self.se is not None: - x = self.se(x) + x = self.se(x) # Point-wise linear projection x = self.conv_pwl(x) diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index f670aa6c..9d5853c7 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -14,13 +14,55 @@ from copy import deepcopy import torch.nn as nn from .efficientnet_blocks import * -from .layers import CondConv2d, get_condconv_initializer +from .layers import CondConv2d, get_condconv_initializer, get_act_layer, make_divisible -__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"] +__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights", + 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT'] _logger = logging.getLogger(__name__) +_DEBUG_BUILDER = False + +# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per +# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay) +# NOTE: momentum varies btw .99 and .9997 depending on source +# .99 in official TF TPU impl +# .9997 (/w .999 in search space) for paper +BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 +BN_EPS_TF_DEFAULT = 1e-3 +_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) + + +def get_bn_args_tf(): + return _BN_ARGS_TF.copy() + + +def resolve_bn_args(kwargs): + bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} + bn_momentum = kwargs.pop('bn_momentum', None) + if bn_momentum is not None: + bn_args['momentum'] = bn_momentum + bn_eps = kwargs.pop('bn_eps', None) + if bn_eps is not None: + bn_args['eps'] = bn_eps + return bn_args + + +def resolve_act_layer(kwargs, default='relu'): + act_layer = kwargs.pop('act_layer', default) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + return act_layer + + +def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9): + """Round number of filters based on depth multiplier.""" + if not multiplier: + return channels + return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit) + + def _log_info_if(msg, condition): if condition: _logger.info(msg) @@ -63,11 +105,13 @@ def _decode_block_str(block_str): block_type = ops[0] # take the block type off the front ops = ops[1:] options = {} - noskip = False + skip = None for op in ops: # string options being checked on individual basis, combine if they grow if op == 'noskip': - noskip = True + skip = False # force no skip connection + elif op == 'skip': + skip = True # force a skip connection elif op.startswith('n'): # activation fn key = op[0] @@ -94,7 +138,7 @@ def _decode_block_str(block_str): act_layer = options['n'] if 'n' in options else None exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 - fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def + force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def num_repeat = int(options['r']) # each type of block has different valid arguments, fill accordingly @@ -106,10 +150,10 @@ def _decode_block_str(block_str): pw_kernel_size=pw_kernel_size, out_chs=int(options['c']), exp_ratio=float(options['e']), - se_ratio=float(options['se']) if 'se' in options else None, + se_ratio=float(options['se']) if 'se' in options else 0., stride=int(options['s']), act_layer=act_layer, - noskip=noskip, + noskip=skip is False, ) if 'cc' in options: block_args['num_experts'] = int(options['cc']) @@ -119,11 +163,11 @@ def _decode_block_str(block_str): dw_kernel_size=_parse_ksize(options['k']), pw_kernel_size=pw_kernel_size, out_chs=int(options['c']), - se_ratio=float(options['se']) if 'se' in options else None, + se_ratio=float(options['se']) if 'se' in options else 0., stride=int(options['s']), act_layer=act_layer, pw_act=block_type == 'dsa', - noskip=block_type == 'dsa' or noskip, + noskip=block_type == 'dsa' or skip is False, ) elif block_type == 'er': block_args = dict( @@ -132,11 +176,11 @@ def _decode_block_str(block_str): pw_kernel_size=pw_kernel_size, out_chs=int(options['c']), exp_ratio=float(options['e']), - fake_in_chs=fake_in_chs, - se_ratio=float(options['se']) if 'se' in options else None, + force_in_chs=force_in_chs, + se_ratio=float(options['se']) if 'se' in options else 0., stride=int(options['s']), act_layer=act_layer, - noskip=noskip, + noskip=skip is False, ) elif block_type == 'cn': block_args = dict( @@ -145,6 +189,7 @@ def _decode_block_str(block_str): out_chs=int(options['c']), stride=int(options['s']), act_layer=act_layer, + skip=skip is True, ) else: assert False, 'Unknown block type (%s)' % block_type @@ -219,19 +264,14 @@ class EfficientNetBuilder: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py """ - def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - output_stride=32, pad_type='', act_layer=None, se_kwargs=None, - norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='', - verbose=False): - self.channel_multiplier = channel_multiplier - self.channel_divisor = channel_divisor - self.channel_min = channel_min + def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, + act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''): self.output_stride = output_stride self.pad_type = pad_type + self.round_chs_fn = round_chs_fn self.act_layer = act_layer - self.se_kwargs = se_kwargs self.norm_layer = norm_layer - self.norm_kwargs = norm_kwargs + self.se_layer = se_layer self.drop_path_rate = drop_path_rate if feature_location == 'depthwise': # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense @@ -239,45 +279,39 @@ class EfficientNetBuilder: feature_location = 'expansion' self.feature_location = feature_location assert feature_location in ('bottleneck', 'expansion', '') - self.verbose = verbose + self.verbose = _DEBUG_BUILDER # state updated during build, consumed by model self.in_chs = None self.features = [] - def _round_channels(self, chs): - return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) - def _make_block(self, ba, block_idx, block_count): drop_path_rate = self.drop_path_rate * block_idx / block_count bt = ba.pop('block_type') ba['in_chs'] = self.in_chs - ba['out_chs'] = self._round_channels(ba['out_chs']) - if 'fake_in_chs' in ba and ba['fake_in_chs']: - # FIXME this is a hack to work around mismatch in origin impl input filters - ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) - ba['norm_layer'] = self.norm_layer - ba['norm_kwargs'] = self.norm_kwargs + ba['out_chs'] = self.round_chs_fn(ba['out_chs']) + if 'force_in_chs' in ba and ba['force_in_chs']: + # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl + ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs']) ba['pad_type'] = self.pad_type # block act fn overrides the model default ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer assert ba['act_layer'] is not None - if bt == 'ir': + ba['norm_layer'] = self.norm_layer + if bt != 'cn': + ba['se_layer'] = self.se_layer ba['drop_path_rate'] = drop_path_rate - ba['se_kwargs'] = self.se_kwargs + + if bt == 'ir': _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) if ba.get('num_experts', 0) > 0: block = CondConvResidual(**ba) else: block = InvertedResidual(**ba) elif bt == 'ds' or bt == 'dsa': - ba['drop_path_rate'] = drop_path_rate - ba['se_kwargs'] = self.se_kwargs _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose) block = DepthwiseSeparableConv(**ba) elif bt == 'er': - ba['drop_path_rate'] = drop_path_rate - ba['se_kwargs'] = self.se_kwargs _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose) block = EdgeResidual(**ba) elif bt == 'cn': @@ -285,8 +319,8 @@ class EfficientNetBuilder: block = ConvBnAct(**ba) else: assert False, 'Uknkown block type (%s) while building model.' % bt - self.in_chs = ba['out_chs'] # update in_chs for arg of next block + self.in_chs = ba['out_chs'] # update in_chs for arg of next block return block def __call__(self, in_chs, model_block_args): diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 358fb4c7..c132142a 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -13,8 +13,8 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid -from .efficientnet_blocks import SqueezeExcite, ConvBnAct, make_divisible +from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid, make_divisible +from .efficientnet_blocks import SqueezeExcite, ConvBnAct from .helpers import build_model_with_cfg from .registry import register_model @@ -110,7 +110,6 @@ class GhostBottleneck(nn.Module): nn.BatchNorm2d(out_chs), ) - def forward(self, x): shortcut = x diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py index 4420172b..231bb4b6 100644 --- a/timm/models/hardcorenas.py +++ b/timm/models/hardcorenas.py @@ -1,10 +1,14 @@ +from functools import partial + import torch.nn as nn -from .efficientnet_builder import decode_arch_def, resolve_bn_args -from .mobilenetv3 import MobileNetV3, MobileNetV3Features, build_model_with_cfg, default_cfg_for_features -from .layers import hard_sigmoid -from .efficientnet_blocks import resolve_act_layer -from .registry import register_model + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .efficientnet_blocks import SqueezeExcite +from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args +from .helpers import build_model_with_cfg, default_cfg_for_features +from .layers import get_act_fn +from .mobilenetv3 import MobileNetV3, MobileNetV3Features +from .registry import register_model def _cfg(url='', **kwargs): @@ -35,15 +39,15 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs): """ num_features = 1280 - + se_layer = partial( + SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8) model_kwargs = dict( block_args=decode_arch_def(arch_def), num_features=num_features, stem_size=32, - channel_multiplier=1, - norm_kwargs=resolve_bn_args(kwargs), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'hard_swish'), - se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8), + se_layer=se_layer, **kwargs, ) diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py index 7a738d5c..64573ef6 100644 --- a/timm/models/layers/helpers.py +++ b/timm/models/layers/helpers.py @@ -22,10 +22,10 @@ to_4tuple = _ntuple(4) to_ntuple = _ntuple -def make_divisible(v, divisor=8, min_value=None): +def make_divisible(v, divisor=8, min_value=None, round_limit=.9): min_value = min_value or divisor new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * v: + if new_v < round_limit * v: new_v += divisor - return new_v + return new_v \ No newline at end of file diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 543b33ea..9afa3d75 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -5,23 +5,25 @@ A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 -Hacked together by / Copyright 2020 Ross Wightman +Hacked together by / Copyright 2021 Ross Wightman """ +from functools import partial +from typing import List + import torch import torch.nn as nn import torch.nn.functional as F -from typing import List - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT -from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights +from .efficientnet_blocks import SqueezeExcite +from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\ + round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .features import FeatureInfo, FeatureHooks from .helpers import build_model_with_cfg, default_cfg_for_features from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid from .registry import register_model -__all__ = ['MobileNetV3'] +__all__ = ['MobileNetV3', 'MobileNetV3Features'] def _cfg(url='', **kwargs): @@ -47,9 +49,11 @@ default_cfgs = { url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_in21k_miil.pth', num_classes=11221), 'mobilenetv3_small_075': _cfg(url=''), 'mobilenetv3_small_100': _cfg(url=''), + 'mobilenetv3_rw': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', interpolation='bicubic'), + 'tf_mobilenetv3_large_075': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), @@ -70,8 +74,6 @@ default_cfgs = { mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), } -_DEBUG = False - class MobileNetV3(nn.Module): """ MobiletNet-V3 @@ -84,24 +86,26 @@ class MobileNetV3(nn.Module): """ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, - channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., - se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'): + pad_type='', act_layer=None, norm_layer=None, se_layer=None, + round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'): super(MobileNetV3, self).__init__() - + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite self.num_classes = num_classes self.num_features = num_features self.drop_rate = drop_rate # Stem - stem_size = round_channels(stem_size, channel_multiplier) + stem_size = round_chs_fn(stem_size) self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) - self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.bn1 = norm_layer(stem_size) self.act1 = act_layer(inplace=True) # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( - channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs, - norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG) + output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, + act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features head_chs = builder.in_chs @@ -158,23 +162,25 @@ class MobileNetV3Features(nn.Module): """ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', - in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='', - act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None, - norm_layer=nn.BatchNorm2d, norm_kwargs=None): + in_chans=3, stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, + act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): super(MobileNetV3Features, self).__init__() - norm_kwargs = norm_kwargs or {} + act_layer = act_layer or nn.ReLU + norm_layer = norm_layer or nn.BatchNorm2d + se_layer = se_layer or SqueezeExcite self.drop_rate = drop_rate # Stem - stem_size = round_channels(stem_size, channel_multiplier) + stem_size = round_chs_fn(stem_size) self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) - self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.bn1 = norm_layer(stem_size) self.act1 = act_layer(inplace=True) # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( - channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs, - norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG) + output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, + act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, + drop_path_rate=drop_path_rate, feature_location=feature_location) self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = FeatureInfo(builder.features, out_indices) self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices} @@ -253,10 +259,10 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw model_kwargs = dict( block_args=decode_arch_def(arch_def), head_bias=False, - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'hard_swish'), - se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1), + se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), reduce_from_block=False), **kwargs, ) model = _create_mnv3(variant, pretrained, **model_kwargs) @@ -344,15 +350,16 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg # stage 6, 7x7 in ['cn_r1_k1_s1_c960'], # hard-swish ] - + se_layer = partial( + SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8) model_kwargs = dict( block_args=decode_arch_def(arch_def), num_features=num_features, stem_size=16, - channel_multiplier=channel_multiplier, - norm_kwargs=resolve_bn_args(kwargs), + round_chs_fn=partial(round_channels, multiplier=channel_multiplier), + norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=act_layer, - se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8), + se_layer=se_layer, **kwargs, ) model = _create_mnv3(variant, pretrained, **model_kwargs)