diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py
index 8ad066c2..101e428a 100644
--- a/timm/models/efficientnet.py
+++ b/timm/models/efficientnet.py
@@ -102,6 +102,8 @@ default_cfgs = {
     'efficientnet_eca_b2': _cfg(
         url='',
         input_size=(3, 260, 260), pool_size=(9, 9)),
+    'xefficientnet_b0': _cfg(
+        url=''),
     'efficientnet_es': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'),
     'efficientnet_em': _cfg(
@@ -242,7 +244,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'),
 }

-_DEBUG = False
+_DEBUG = True


 class EfficientNet(nn.Module):
@@ -658,6 +660,54 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
     return model


+def _gen_xefficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+    """Creates an XEfficientNet model (EfficientNet w/ decomposed 'x' depthwise blocks).
+
+    Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
+    Paper: https://arxiv.org/abs/1905.11946
+
+    EfficientNet params
+    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
+    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
+    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
+    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
+    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
+    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
+    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
+    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
+    'efficientnet-b7': (2.0, 3.1, 600, 0.5),
+    'efficientnet-b8': (2.2, 3.6, 672, 0.5),
+    'efficientnet-l2': (4.3, 5.3, 800, 0.5),
+
+    Args:
+      channel_multiplier: multiplier to number of channels per layer
+      depth_multiplier: multiplier to number of repeats per stage
+
+    """
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16'],
+        ['ir_r2_k3_s2_e6_c24'],
+        ['ir_r1_k5_s2_e6_c40', 'xir_r1_k5_s1_e6_c40'],
+        ['ir_r3_k3_s2_e6_c80'],
+        ['xir_r3_k5_s1_e6_c112'],
+        ['ir_r1_k5_s2_e6_c192', 'xir_r3_k5_s1_e6_c192'],
+        ['xir_r1_k5_s1_e6_c320'],
+    ]
+    model_kwargs = dict(
+        block_args=decode_arch_def(arch_def, depth_multiplier),
+        num_features=round_channels(1280, channel_multiplier, 8, None),
+        stem_size=32,
+        channel_multiplier=channel_multiplier,
+        act_layer=Swish,
+        attn_layer='sev2',
+        attn_kwargs=dict(se_ratio=0.25),
+        norm_kwargs=resolve_bn_args(kwargs),
+        **kwargs,
+    )
+    model = _create_model(model_kwargs, default_cfgs[variant], pretrained)
+    return model
+
+
 def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
     """ Creates an EfficientNet-EdgeTPU model

@@ -1064,6 +1114,15 @@ def efficientnet_eca_b2(pretrained=False, **kwargs):
     return model


+@register_model
+def xefficientnet_b0(pretrained=False, **kwargs):
+    """ XEfficientNet-B0 """
+    # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+    model = _gen_xefficientnet(
+        'xefficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+    return model
+
+
 @register_model
 def efficientnet_es(pretrained=False, **kwargs):
     """ EfficientNet-Edge Small. """
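For reference, a quick smoke test of the new entry point. This is only a sketch assuming this branch of timm is importable; no pretrained weights exist yet (the cfg url above is empty), so the model comes up randomly initialized.

import torch
import timm

# 'xefficientnet_b0' is the variant registered above; drop rates follow the NOTE in that function.
model = timm.create_model('xefficientnet_b0', pretrained=False, drop_rate=0.2, drop_path_rate=0.2)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # 224x224 is the B0 resolution from the param table
print(out.shape)  # expected: torch.Size([1, 1000])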
""" diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index aeacc153..490f8716 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -244,6 +244,161 @@ class InvertedResidual(nn.Module): return x +class XDepthwiseSeparableConv(nn.Module): + """ DepthwiseSeparable block + Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion + (factor of 1.0). This is an alternative to having a IR with an optional first pw conv. + """ + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, + pw_kernel_size=1, pw_act=False, attn_layer=None, attn_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.): + super(XDepthwiseSeparableConv, self).__init__() + norm_kwargs = norm_kwargs or {} + self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.has_pw_act = pw_act # activation after point-wise conv + self.drop_path_rate = drop_path_rate + + conv_kwargs = {} + self.conv_dw_2x2 = create_conv2d( + in_chs, in_chs, 2, stride=stride, dilation=dilation, + padding='same', depthwise=True, **conv_kwargs) + self.conv_dw_1xk = create_conv2d( + in_chs, in_chs, (1, dw_kernel_size), stride=stride, dilation=dilation, + padding=pad_type, depthwise=True, **conv_kwargs) + self.conv_dw_kx1 = create_conv2d( + in_chs, in_chs, (dw_kernel_size, 1), stride=stride, dilation=dilation, + padding=pad_type, depthwise=True, **conv_kwargs) + self.bn1 = norm_layer(in_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Attention block (Squeeze-Excitation, ECA, etc) + if attn_layer is not None: + attn_kwargs = resolve_attn_args(attn_layer, attn_kwargs, in_chs, act_layer) + self.se = create_attn(attn_layer, in_chs, **attn_kwargs) + else: + self.se = None + + self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() + + def feature_module(self, location): + # no expansion in this block, pre pw only feature extraction point + return 'conv_pw' + + def feature_channels(self, location): + return self.conv_pw.in_channels + + def forward(self, x): + residual = x + + x = self.conv_dw_2x2(x) + x = self.conv_dw_1xk(x) + x = self.conv_dw_kx1(x) + x = self.bn1(x) + x = self.act1(x) + + if self.se is not None: + x = self.se(x) + + x = self.conv_pw(x) + x = self.bn2(x) + x = self.act2(x) + + if self.has_residual: + if self.drop_path_rate > 0.: + x = drop_path(x, self.drop_path_rate, self.training) + x += residual + return x + + +class XInvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE and CondConv routing""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, pad_shift=0, + attn_layer=None, attn_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + conv_kwargs=None, drop_path_rate=0.): + super(XInvertedResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + conv_kwargs = conv_kwargs or {} + mid_chs = make_divisible(in_chs * exp_ratio) + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_path_rate = drop_path_rate + + # Point-wise expansion + self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 
diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py
index bf13d60e..5975f14f 100644
--- a/timm/models/efficientnet_builder.py
+++ b/timm/models/efficientnet_builder.py
@@ -90,7 +90,7 @@ def _decode_block_str(block_str):
     num_repeat = int(options['r'])

     # each type of block has different valid arguments, fill accordingly
-    if block_type == 'ir':
+    if block_type == 'ir' or block_type == 'xir':
         block_args = dict(
             block_type=block_type,
             dw_kernel_size=_parse_ksize(options['k']),
@@ -106,7 +106,7 @@ def _decode_block_str(block_str):
         )
         if 'cc' in options:
             block_args['num_experts'] = int(options['cc'])
-    elif block_type == 'ds' or block_type == 'dsa':
+    elif block_type == 'ds' or block_type == 'dsa' or block_type == 'xds':
         block_args = dict(
             block_type=block_type,
             dw_kernel_size=_parse_ksize(options['k']),
@@ -232,6 +232,7 @@ class EfficientNetBuilder:

         # state updated during build, consumed by model
         self.in_chs = None
+        self.x_count = 0
         self.features = OrderedDict()

     def _round_channels(self, chs):
@@ -261,33 +262,35 @@
                 ba['attn_kwargs'] = self.attn_kwargs
             else:
                 ba['attn_kwargs'].update(self.attn_kwargs)
+        ba['drop_path_rate'] = drop_path_rate
         if bt == 'ir':
-            ba['drop_path_rate'] = drop_path_rate
-            if self.verbose:
-                logging.info('  InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
             if ba.get('num_experts', 0) > 0:
                 block = CondConvResidual(**ba)
             else:
                 block = InvertedResidual(**ba)
+        elif bt == 'xir':
+            ba['pad_shift'] = self.x_count
+            block = XInvertedResidual(**ba)
+            self.x_count = (self.x_count + 1) % 4
         elif bt == 'ds' or bt == 'dsa':
-            ba['drop_path_rate'] = drop_path_rate
-            if self.verbose:
-                logging.info('  DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
             block = DepthwiseSeparableConv(**ba)
+        elif bt == 'xds':
+            ba['pad_shift'] = self.x_count
+            block = XDepthwiseSeparableConv(**ba)
+            self.x_count = (self.x_count + 1) % 4
         elif bt == 'er':
-            ba['drop_path_rate'] = drop_path_rate
-            if self.verbose:
-                logging.info('  EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
             block = EdgeResidual(**ba)
         elif bt == 'cn':
-            if self.verbose:
-                logging.info('  ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
+            del ba['drop_path_rate']
             block = ConvBnAct(**ba)
         else:
             assert False, 'Uknkown block type (%s) while building model.' % bt
         self.in_chs = ba['out_chs']  # update in_chs for arg of next block
+        if self.verbose:
+            logging.info('  {} {}, Args: {}'.format(block.__class__.__name__, block_idx, str(ba)))
+
         return block

     def __call__(self, in_chs, model_block_args):
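The builder keeps a single x_count for the whole network and advances it after every 'xir'/'xds' block, so consecutive x-blocks cycle through the four pad_shift values (the four padding corners handled in pad_same() below). A tiny sketch of that bookkeeping with a made-up block sequence:

# Mirrors the x_count cycling above; the block type sequence is invented for illustration.
block_types = ['ir', 'xir', 'xir', 'ds', 'xir', 'xir', 'xir']
x_count = 0
for i, bt in enumerate(block_types):
    if bt in ('xir', 'xds'):
        print(f'block {i} ({bt}): pad_shift={x_count}')
        x_count = (x_count + 1) % 4
# x-blocks get pad_shift 0, 1, 2, 3, 0 in order, regardless of interleaved regular blocks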
diff --git a/timm/models/layers/conv2d_same.py b/timm/models/layers/conv2d_same.py
index 0e29ae8c..b3592e50 100644
--- a/timm/models/layers/conv2d_same.py
+++ b/timm/models/layers/conv2d_same.py
@@ -13,8 +13,8 @@ from .padding import get_padding, pad_same, is_static_pad

 def conv2d_same(
         x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
-        padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
-    x = pad_same(x, weight.shape[-2:], stride, dilation)
+        padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1, pad_shift: int = 0):
+    x = pad_same(x, weight.shape[-2:], stride, dilation, pad_shift)
     return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)


@@ -23,12 +23,14 @@ class Conv2dSame(nn.Conv2d):
     """

     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
-                 padding=0, dilation=1, groups=1, bias=True):
+                 padding=0, dilation=1, groups=1, bias=True, pad_shift=0):
         super(Conv2dSame, self).__init__(
             in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
+        self.pad_shift = pad_shift

     def forward(self, x):
-        return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+        return conv2d_same(
+            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups, self.pad_shift)


 def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py
index b1d72ba1..4d0c5040 100644
--- a/timm/models/layers/helpers.py
+++ b/timm/models/layers/helpers.py
@@ -7,7 +7,7 @@ from torch._six import container_abcs


 # From PyTorch internals
-def _ntuple(n):
+def ntuple(n):
     def parse(x):
         if isinstance(x, container_abcs.Iterable):
             return x
@@ -15,10 +15,10 @@ def _ntuple(n):
     return parse


-tup_single = _ntuple(1)
-tup_pair = _ntuple(2)
-tup_triple = _ntuple(3)
-tup_quadruple = _ntuple(4)
+tup_single = ntuple(1)
+tup_pair = ntuple(2)
+tup_triple = ntuple(3)
+tup_quadruple = ntuple(4)


 def make_divisible(v, divisor=8, min_value=None):
diff --git a/timm/models/layers/padding.py b/timm/models/layers/padding.py
index b3653866..436ed54f 100644
--- a/timm/models/layers/padding.py
+++ b/timm/models/layers/padding.py
@@ -4,12 +4,17 @@ Hacked together by Ross Wightman
 """
 import math
 from typing import List
+from .helpers import ntuple

 import torch.nn.functional as F


 # Calculate symmetric padding for a convolution
 def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
+    if isinstance(kernel_size, (list, tuple)):
+        stride = ntuple(len(kernel_size))(stride)
+        dilation = ntuple(len(kernel_size))(dilation)
+        return [get_padding(k, s, d) for k, s, d in zip(kernel_size, stride, dilation)]
     padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
     return padding

@@ -25,9 +30,17 @@ def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):


 # Dynamically pad input x with 'SAME' padding for conv with specified args
-def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)):
+def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), shift: int = 0):
     ih, iw = x.size()[-2:]
     pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
     if pad_h > 0 or pad_w > 0:
-        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+        if shift == 0:
+            pl = [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]  # ul
+        elif shift == 1:
+            pl = [pad_w - pad_w // 2, pad_w // 2, pad_h - pad_h // 2, pad_h // 2]  # lr
+        elif shift == 2:
+            pl = [pad_w - pad_w // 2, pad_w // 2, pad_h // 2, pad_h - pad_h // 2]  # ur
+        else:
+            pl = [pad_w // 2, pad_w - pad_w // 2, pad_h - pad_h // 2, pad_h // 2]  # ll
+        x = F.pad(x, pl)
     return x
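The shift argument only changes behaviour when the total 'SAME' padding along an axis is odd, which is exactly the case for the 2x2 depthwise convs at stride 1; it then decides which side receives the extra pixel. Reproducing the pad-list arithmetic from pad_same() above for pad_h = pad_w = 1 (F.pad order is [left, right, top, bottom]):

# Pad lists chosen by pad_same() for each shift with a 2x2 kernel at stride 1 (pad_h = pad_w = 1).
pad_h = pad_w = 1
layouts = {
    0: [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],  # ul
    1: [pad_w - pad_w // 2, pad_w // 2, pad_h - pad_h // 2, pad_h // 2],  # lr
    2: [pad_w - pad_w // 2, pad_w // 2, pad_h // 2, pad_h - pad_h // 2],  # ur
    3: [pad_w // 2, pad_w - pad_w // 2, pad_h - pad_h // 2, pad_h // 2],  # ll
}
for shift, pl in layouts.items():
    print(shift, pl)
# 0 -> [0, 1, 0, 1]  extra pixel on right/bottom
# 1 -> [1, 0, 1, 0]  extra pixel on left/top
# 2 -> [1, 0, 0, 1]  extra pixel on left/bottom
# 3 -> [0, 1, 1, 0]  extra pixel on right/top

When the padding total is even (e.g. the 1xk and kx1 convs with odd k), all four layouts coincide and the shift is a no-op.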
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
index 5ac66112..9c1d8b4d 100644
--- a/timm/models/mobilenetv3.py
+++ b/timm/models/mobilenetv3.py
@@ -35,6 +35,7 @@ default_cfgs = {
     'mobilenetv3_small_075': _cfg(url='', interoplation='bicubic'),
     'mobilenetv3_small_100': _cfg(url='', interoplation='bicubic'),
     'mobilenetv3_eca_large': _cfg(url='', interoplation='bicubic'),
+    'xmobilenetv3_large_100': _cfg(url='', interpolation='bicubic'),
     'mobilenetv3_rw': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
         interpolation='bicubic'),
@@ -58,7 +59,7 @@ default_cfgs = {
         mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
 }

-_DEBUG = False
+_DEBUG = True


 class MobileNetV3(nn.Module):
@@ -360,6 +361,102 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
     return model


+def _gen_xmobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+    """Creates an XMobileNet-V3 model (MobileNet-V3 w/ decomposed 'x' depthwise blocks).
+
+    Ref impl: ?
+    Paper: https://arxiv.org/abs/1905.02244
+
+    Args:
+      channel_multiplier: multiplier to number of channels per layer.
+ """ + if 'small' in variant: + num_features = 1024 + if 'minimal' in variant: + act_layer = nn.ReLU + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16'], + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], + # stage 2, 28x28 in + ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], + # stage 3, 14x14 in + ['ir_r2_k3_s1_e3_c48'], + # stage 4, 14x14in + ['ir_r3_k3_s2_e6_c96'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], + ] + else: + act_layer = HardSwish + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu + # stage 2, 28x28 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish + # stage 3, 14x14 in + ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish + # stage 4, 14x14in + ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], # hard-swish + ] + else: + num_features = 1280 + if 'minimal' in variant: + act_layer = nn.ReLU + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k3_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112'], + # stage 5, 14x14in + ['ir_r3_k3_s2_e6_c160'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + act_layer = HardSwish + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre', 'xir_r2_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['xir_r2_k5_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r1_k5_s2_e6_c160_se0.25', 'xir_r2_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=16, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=act_layer, + attn_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisible_by=8), + **kwargs, + ) + model = _create_model(model_kwargs, default_cfgs[variant], pretrained) + return model + + def _gen_mobilenet_v3_eca(variant, channel_multiplier=1.0, pretrained=False, **kwargs): """Creates a MobileNet-V3 model. @@ -471,6 +568,13 @@ def mobilenetv3_eca_large(pretrained=False, **kwargs): return model +@register_model +def xmobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 """ + model = _gen_xmobilenet_v3('xmobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + @register_model def mobilenetv3_rw(pretrained=False, **kwargs): """ MobileNet V3 """