diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py
index 9a8c347f..9460e9af 100644
--- a/timm/models/gen_efficientnet.py
+++ b/timm/models/gen_efficientnet.py
@@ -88,6 +88,14 @@ default_cfgs = {
         url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
     'efficientnet_b7': _cfg(
         url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
+    'efficientnet_es': _cfg(
+        url=''),
+    'efficientnet_em': _cfg(
+        url='',
+        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
+    'efficientnet_el': _cfg(
+        url='',
+        input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
     'tf_efficientnet_b0': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
         input_size=(3, 224, 224)),
@@ -112,6 +120,18 @@ default_cfgs = {
     'tf_efficientnet_b7': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth',
         input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
+    'tf_efficientnet_es': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 224, 224), ),
+    'tf_efficientnet_em': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
+    'tf_efficientnet_el': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
     'mixnet_s': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'),
     'mixnet_m': _cfg(
@@ -239,6 +259,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
     act_fn = options['n'] if 'n' in options else None
     exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
     pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
+    fake_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
 
     num_repeat = int(options['r'])
     # each type of block has different valid arguments, fill accordingly
@@ -267,6 +288,19 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
             pw_act=block_type == 'dsa',
             noskip=block_type == 'dsa' or noskip,
         )
+    elif block_type == 'er':
+        block_args = dict(
+            block_type=block_type,
+            exp_kernel_size=_parse_ksize(options['k']),
+            pw_kernel_size=pw_kernel_size,
+            out_chs=int(options['c']),
+            exp_ratio=float(options['e']),
+            fake_in_chs=fake_in_chs,
+            se_ratio=float(options['se']) if 'se' in options else None,
+            stride=int(options['s']),
+            act_fn=act_fn,
+            noskip=noskip,
+        )
     elif block_type == 'cn':
         block_args = dict(
             block_type=block_type,
@@ -356,6 +390,9 @@ class _BlockBuilder:
         bt = ba.pop('block_type')
         ba['in_chs'] = self.in_chs
         ba['out_chs'] = self._round_channels(ba['out_chs'])
+        if 'fake_in_chs' in ba and ba['fake_in_chs']:
+            # FIXME this is a hack to work around mismatch in origin impl input filters
+            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
         ba['bn_args'] = self.bn_args
         ba['pad_type'] = self.pad_type
         # block act fn overrides the model default
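Aside, for review: here is roughly what the first of the new edge block strings used further down ('er_r1_k3_s1_e4_c24_fc24_noskip') unpacks to. The splitter below is a deliberately simplified stand-in for the option parsing in _decode_block_str (the file itself uses _parse_ksize and a regex split), so read it as an illustration rather than the implementation:

# Simplified sketch of the option split performed by _decode_block_str.
block_str = 'er_r1_k3_s1_e4_c24_fc24_noskip'

ops = block_str.split('_')
block_type = ops[0]   # 'er' now dispatches to the new EdgeResidual block
noskip = 'noskip' in ops

options = {}
for op in ops[1:]:
    if op == 'noskip':
        continue
    # split 'fc24' -> ('fc', '24'), 'e4' -> ('e', '4'), and so on
    idx = next(i for i, ch in enumerate(op) if ch.isdigit())
    options[op[:idx]] = op[idx:]

assert block_type == 'er'
assert int(options['r']) == 1      # repeats
assert int(options['k']) == 3      # expansion conv kernel size
assert int(options['s']) == 1      # stride
assert float(options['e']) == 4.0  # expansion ratio
assert int(options['c']) == 24     # out_chs
assert int(options['fc']) == 24    # fake_in_chs, the stem-mismatch workaround
assert noskip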
@@ -373,6 +410,13 @@ class _BlockBuilder:
             if self.verbose:
                 logging.info('  DepthwiseSeparable {}, Args: {}'.format(self.block_idx, str(ba)))
             block = DepthwiseSeparableConv(**ba)
+        elif bt == 'er':
+            ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
+            ba['se_gate_fn'] = self.se_gate_fn
+            ba['se_reduce_mid'] = self.se_reduce_mid
+            if self.verbose:
+                logging.info('  EdgeResidual {}, Args: {}'.format(self.block_idx, str(ba)))
+            block = EdgeResidual(**ba)
         elif bt == 'cn':
             if self.verbose:
                 logging.info('  ConvBnAct {}, Args: {}'.format(self.block_idx, str(ba)))
@@ -519,10 +563,62 @@ class ConvBnAct(nn.Module):
         return x
 
 
+class EdgeResidual(nn.Module):
+    """ Residual block with expansion convolution followed by pointwise-linear w/ stride"""
+
+    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
+                 stride=1, pad_type='', act_fn=F.relu, noskip=False, pw_kernel_size=1,
+                 se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
+                 bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
+        super(EdgeResidual, self).__init__()
+        mid_chs = int(fake_in_chs * exp_ratio) if fake_in_chs > 0 else int(in_chs * exp_ratio)
+        self.has_se = se_ratio is not None and se_ratio > 0.
+        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
+        self.act_fn = act_fn
+        self.drop_connect_rate = drop_connect_rate
+
+        # Expansion convolution
+        self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
+        self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)
+
+        # Squeeze-and-excitation
+        if self.has_se:
+            se_base_chs = mid_chs if se_reduce_mid else in_chs
+            self.se = SqueezeExcite(
+                mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)
+
+        # Point-wise linear projection
+        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
+        self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)
+
+    def forward(self, x):
+        residual = x
+
+        # Expansion convolution
+        x = self.conv_exp(x)
+        x = self.bn1(x)
+        x = self.act_fn(x, inplace=True)
+
+        # Squeeze-and-excitation
+        if self.has_se:
+            x = self.se(x)
+
+        # Point-wise linear projection
+        x = self.conv_pwl(x)
+        x = self.bn2(x)
+
+        if self.has_residual:
+            if self.drop_connect_rate > 0.:
+                x = drop_connect(x, self.training, self.drop_connect_rate)
+            x += residual
+
+        return x
+
+
 class DepthwiseSeparableConv(nn.Module):
     """ DepthwiseSeparable block
     Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
-    factor of 1.0. This is an alternative to having a IR with optional first pw conv.
+    factor of 1.0. This is an alternative to having an IR with an optional first pw conv.
     """
     def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                  stride=1, pad_type='', act_fn=F.relu, noskip=False,
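To see why fake_in_chs exists at all, it helps to trace the channel math of EdgeResidual by hand for the first edge block ('er_r1_k3_s1_e4_c24_fc24_noskip', which sits right after the 32-channel stem). The numbers below are worked out from the __init__ above; the variable names are local to this sketch:

# Channel bookkeeping of EdgeResidual.__init__ for the first edge block.
in_chs = 32        # actual stem output (_gen_efficientnet_edge uses stem_size=32)
fake_in_chs = 24   # 'fc24': the TF reference computes the expansion from 24, not 32
out_chs = 24       # 'c24'
exp_ratio = 4.0    # 'e4'
stride, noskip = 1, True  # 's1' plus the '_noskip' suffix

# mirrors the mid_chs line in __init__
mid_chs = int(fake_in_chs * exp_ratio) if fake_in_chs > 0 else int(in_chs * exp_ratio)
assert mid_chs == 96  # conv_exp maps 32 -> 96; without 'fc24' it would map 32 -> 128

# mirrors the has_residual line in __init__
has_residual = (in_chs == out_chs and stride == 1) and not noskip
assert not has_residual  # channels differ (32 vs 24), and noskip is set anyway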
""" def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1, pad_type='', act_fn=F.relu, noskip=False, @@ -1092,7 +1188,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes= ['ir_r4_k5_s2_e6_c192_se0.25'], ['ir_r1_k3_s1_e6_c320_se0.25'], ] - # NOTE: other models in the family didn't scale the feature count num_features = _round_channels(1280, channel_multiplier, 8, None) model = GenEfficientNet( _decode_arch_def(arch_def, depth_multiplier), @@ -1107,6 +1202,31 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes= return model +def _gen_efficientnet_edge(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs): + arch_def = [ + # NOTE `fc` is present to override a mismatch between stem channels and in chs not + # present in other models + ['er_r1_k3_s1_e4_c24_fc24_noskip'], + ['er_r2_k3_s2_e8_c32'], + ['er_r4_k3_s2_e8_c48'], + ['ir_r5_k5_s2_e8_c96'], + ['ir_r4_k5_s1_e8_c144'], + ['ir_r2_k5_s2_e8_c192'], + ] + num_features = _round_channels(1280, channel_multiplier, 8, None) + model = GenEfficientNet( + _decode_arch_def(arch_def, depth_multiplier), + num_classes=num_classes, + stem_size=32, + channel_multiplier=channel_multiplier, + num_features=num_features, + bn_args=_resolve_bn_args(kwargs), + act_fn=F.relu, + **kwargs + ) + return model + + def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs): """Creates a MixNet Small model. @@ -1481,7 +1601,6 @@ def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model - @register_model def efficientnet_b6(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B6 """ @@ -1512,6 +1631,45 @@ def efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ EfficientNet-Edge Small. """ + default_cfg = default_cfgs['efficientnet_es'] + model = _gen_efficientnet_edge( + channel_multiplier=1.0, depth_multiplier=1.0, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ EfficientNet-Edge-Medium. """ + default_cfg = default_cfgs['efficientnet_em'] + model = _gen_efficientnet_edge( + channel_multiplier=1.0, depth_multiplier=1.1, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ EfficientNet-Edge-Large. """ + default_cfg = default_cfgs['efficientnet_el'] + model = _gen_efficientnet_edge( + channel_multiplier=1.2, depth_multiplier=1.4, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def tf_efficientnet_b0(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B0. Tensorflow compatible variant """ @@ -1634,6 +1792,51 @@ def tf_efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs) return model +@register_model +def tf_efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ EfficientNet-Edge Small. 
@@ -1634,6 +1792,51 @@ def tf_efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
     return model
 
 
+@register_model
+def tf_efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """ EfficientNet-Edge Small. Tensorflow compatible variant """
+    default_cfg = default_cfgs['tf_efficientnet_es']
+    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet_edge(
+        channel_multiplier=1.0, depth_multiplier=1.0,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def tf_efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """ EfficientNet-Edge-Medium. Tensorflow compatible variant """
+    default_cfg = default_cfgs['tf_efficientnet_em']
+    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet_edge(
+        channel_multiplier=1.0, depth_multiplier=1.1,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def tf_efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """ EfficientNet-Edge-Large. Tensorflow compatible variant """
+    default_cfg = default_cfgs['tf_efficientnet_el']
+    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
+    kwargs['pad_type'] = 'same'
+    model = _gen_efficientnet_edge(
+        channel_multiplier=1.2, depth_multiplier=1.4,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
 @register_model
 def mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Creates a MixNet Small model.
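Once merged, the new variants should be reachable through the usual factory. A quick smoke test (a sketch: it assumes this branch is on the import path and the tf_* weight URLs above are live):

import torch
import timm  # this branch

# @register_model exposes the new names to the factory; the tf_* variants pull
# the converted TF weights from the release URLs in default_cfgs.
model = timm.create_model('tf_efficientnet_es', pretrained=True)
model.eval()

# default_cfg carries the expected resolution: 224 for es, 240 for em, 300 for el
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])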