From 7ba5a384d3a485e9d637d9ad986ebbc85bdfb567 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 Jul 2020 10:28:57 -0700 Subject: [PATCH] Add ReXNet w/ remapped weights, feature support --- tests/test_models.py | 2 +- timm/models/__init__.py | 1 + timm/models/layers/conv_bn_act.py | 9 + timm/models/layers/separable_conv.py | 16 ++ timm/models/layers/split_attn.py | 8 + timm/models/rexnet.py | 257 +++++++++++++++++++++++++++ 6 files changed, 292 insertions(+), 1 deletion(-) create mode 100644 timm/models/rexnet.py diff --git a/tests/test_models.py b/tests/test_models.py index 966c5bd8..27e9b515 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -113,7 +113,7 @@ EXCLUDE_FEAT_FILTERS = [ ] if 'GITHUB_ACTIONS' in os.environ and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models - EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d'] + EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d'] @pytest.mark.timeout(120) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index bc3c871a..445fbc5c 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -16,6 +16,7 @@ from .regnet import * from .res2net import * from .resnest import * from .resnet import * +from .rexnet import * from .selecsls import * from .senet import * from .sknet import * diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py index 43f6760e..ea9d0128 100644 --- a/timm/models/layers/conv_bn_act.py +++ b/timm/models/layers/conv_bn_act.py @@ -14,6 +14,7 @@ class ConvBnAct(nn.Module): drop_block=None, aa_layer=None): super(ConvBnAct, self).__init__() use_aa = aa_layer is not None + self.conv = create_conv2d( in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, padding=padding, dilation=dilation, groups=groups, bias=False) @@ -23,6 +24,14 @@ class ConvBnAct(nn.Module): self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None + @property + def in_channels(self): + return self.conv.in_channels + + @property + def out_channels(self): + return self.conv.out_channels + def forward(self, x): x = self.conv(x) x = self.bn(x) diff --git a/timm/models/layers/separable_conv.py b/timm/models/layers/separable_conv.py index 3df0387a..a422e87e 100644 --- a/timm/models/layers/separable_conv.py +++ b/timm/models/layers/separable_conv.py @@ -23,6 +23,14 @@ class SeparableConvBnAct(nn.Module): norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs) self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args) + @property + def in_channels(self): + return self.conv_dw.in_channels + + @property + def out_channels(self): + return self.conv_pw.out_channels + def forward(self, x): x = self.conv_dw(x) x = self.conv_pw(x) @@ -45,6 +53,14 @@ class SeparableConv2d(nn.Module): self.conv_pw = create_conv2d( int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) + @property + def in_channels(self): + return self.conv_dw.in_channels + + @property + def out_channels(self): + return self.conv_pw.out_channels + def forward(self, x): x = self.conv_dw(x) x = self.conv_pw(x) diff --git a/timm/models/layers/split_attn.py b/timm/models/layers/split_attn.py index 023ab6af..5615aa0b 100644 --- a/timm/models/layers/split_attn.py +++ b/timm/models/layers/split_attn.py @@ -51,6 +51,14 @@ class SplitAttnConv2d(nn.Module): self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) self.rsoftmax = RadixSoftmax(radix, groups) + @property + def in_channels(self): + return self.conv.in_channels + + @property + def out_channels(self): + return self.fc1.out_channels + def forward(self, x): x = self.conv(x) if self.bn0 is not None: diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py new file mode 100644 index 00000000..8997d40c --- /dev/null +++ b/timm/models/rexnet.py @@ -0,0 +1,257 @@ +""" ReXNet + +A PyTorch impl of `ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network` - +https://arxiv.org/abs/2007.00992 + +Adapted from original impl at https://github.com/clovaai/rexnet +Copyright (c) 2020-present NAVER Corp. MIT license + +Changes for timm, feature extraction, and rounded channel variant hacked together by Ross Wightman +""" + +import torch.nn as nn +from math import ceil + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .layers import ClassifierHead, create_act_layer, ConvBnAct +from .registry import register_model + + +def _cfg(url=''): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + } + + +default_cfgs = dict( + rexnet_100=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_100-1b4dddf4.pth'), + rexnet_130=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_130-590d768e.pth'), + rexnet_150=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_150-bd1a6aa8.pth'), + rexnet_200=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_200-8c0b7f2d.pth'), + rexnetr_100=_cfg( + url=''), + rexnetr_130=_cfg( + url=''), + rexnetr_150=_cfg( + url=''), + rexnetr_200=_cfg( + url=''), +) + + +def make_divisible(v, divisor=8, min_value=None): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + return new_v + + +class SEWithNorm(nn.Module): + + def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisor=1, reduction_channels=None, + gate_layer='sigmoid'): + super(SEWithNorm, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + reduction_channels = reduction_channels or make_divisible(channels // reduction, divisor=divisor) + print(reduction_channels) + self.fc1 = nn.Conv2d( + channels, reduction_channels, kernel_size=1, padding=0, bias=True) + self.bn = nn.BatchNorm2d(reduction_channels) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d( + reduction_channels, channels, kernel_size=1, padding=0, bias=True) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.fc1(x_se) + x_se = self.bn(x_se) + x_se = self.act(x_se) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +class LinearBottleneck(nn.Module): + def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, use_se=True, se_rd=12, ch_div=1): + super(LinearBottleneck, self).__init__() + self.use_shortcut = stride == 1 and in_chs <= out_chs + self.in_channels = in_chs + self.out_channels = out_chs + + if exp_ratio != 1.: + dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div) + self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer="swish") + else: + dw_chs = in_chs + self.conv_exp = None + + self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False) + self.se = SEWithNorm(dw_chs, reduction=se_rd, divisor=ch_div) if use_se else None + self.act_dw = nn.ReLU6() + + self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False) + + def feat_channels(self, exp=False): + return self.conv_dw.out_channels if exp else self.out_channels + + def forward(self, x): + shortcut = x + if self.conv_exp is not None: + x = self.conv_exp(x) + x = self.conv_dw(x) + if self.se is not None: + x = self.se(x) + x = self.act_dw(x) + x = self.conv_pwl(x) + if self.use_shortcut: + x[:, 0:self.in_channels] += shortcut + return x + + +def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, use_se=True, ch_div=1): + layers = [1, 2, 2, 3, 3, 5] + strides = [1, 2, 2, 2, 1, 2] + layers = [ceil(element * depth_mult) for element in layers] + strides = sum([[element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides)], []) + exp_ratios = [1] * layers[0] + [6] * sum(layers[1:]) + depth = sum(layers[:]) * 3 + base_chs = initial_chs / width_mult if width_mult < 1.0 else initial_chs + + # The following channel configuration is a simple instance to make each layer become an expand layer. + out_chs_list = [] + for i in range(depth // 3): + out_chs_list.append(make_divisible(round(base_chs * width_mult), divisor=ch_div)) + base_chs += final_chs / (depth // 3 * 1.0) + + if use_se: + use_ses = [False] * (layers[0] + layers[1]) + [True] * sum(layers[2:]) + else: + use_ses = [False] * sum(layers[:]) + + return zip(out_chs_list, exp_ratios, strides, use_ses) + + +def _build_blocks(block_cfg, prev_chs, width_mult, se_rd=12, ch_div=1, feature_location='bottleneck'): + feat_exp = feature_location == 'expansion' + feat_chs = [prev_chs] + feature_info = [] + curr_stride = 2 + features = [] + for block_idx, (chs, exp_ratio, stride, se) in enumerate(block_cfg): + if stride > 1: + fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}' + if block_idx > 0 and feat_exp: + fname += '.act_dw' + feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)] + curr_stride *= stride + features.append(LinearBottleneck( + in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, use_se=se, se_rd=se_rd, ch_div=ch_div)) + prev_chs = chs + feat_chs += [features[-1].feat_channels(feat_exp)] + pen_chs = make_divisible(1280 * width_mult, divisor=ch_div) + feature_info += [dict( + num_chs=pen_chs if feat_exp else feat_chs[-1], reduction=curr_stride, + module=f'features.{len(features) - int(not feat_exp)}')] + features.append(ConvBnAct(prev_chs, pen_chs, act_layer="swish")) + return features, feature_info + + +class ReXNetV1(nn.Module): + def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, + initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, use_se=True, + se_rd=12, ch_div=1, drop_rate=0.2, feature_location='bottleneck'): + super(ReXNetV1, self).__init__() + + assert output_stride == 32 # FIXME support dilation + stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32 + stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div) + self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer='swish') + + block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, use_se, ch_div) + features, self.feature_info = _build_blocks( + block_cfg, stem_chs, width_mult, se_rd, ch_div, feature_location) + self.num_features = features[-1].out_channels + self.features = nn.Sequential(*features) + + self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate) + + # FIXME weight init, the original appears to use PyTorch defaults + + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.features(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def _create_rexnet(variant, pretrained, **kwargs): + feature_cfg = dict(flatten_sequential=True) + if kwargs.get('feature_location', '') == 'expansion': + feature_cfg['feature_cls'] = 'hook' + return build_model_with_cfg( + ReXNetV1, variant, pretrained, default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, **kwargs) + + +@register_model +def rexnet_100(pretrained=False, **kwargs): + """ReXNet V1 1.0x""" + return _create_rexnet('rexnet_100', pretrained, **kwargs) + + +@register_model +def rexnet_130(pretrained=False, **kwargs): + """ReXNet V1 1.3x""" + return _create_rexnet('rexnet_130', pretrained, width_mult=1.3, **kwargs) + + +@register_model +def rexnet_150(pretrained=False, **kwargs): + """ReXNet V1 1.5x""" + return _create_rexnet('rexnet_150', pretrained, width_mult=1.5, **kwargs) + + +@register_model +def rexnet_200(pretrained=False, **kwargs): + """ReXNet V1 2.0x""" + return _create_rexnet('rexnet_200', pretrained, width_mult=2.0, **kwargs) + + +@register_model +def rexnetr_100(pretrained=False, **kwargs): + """ReXNet V1 1.0x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_100', pretrained, ch_div=8, **kwargs) + + +@register_model +def rexnetr_130(pretrained=False, **kwargs): + """ReXNet V1 1.3x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_130', pretrained, width_mult=1.3, ch_div=8, **kwargs) + + +@register_model +def rexnetr_150(pretrained=False, **kwargs): + """ReXNet V1 1.5x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_150', pretrained, width_mult=1.5, ch_div=8, **kwargs) + + +@register_model +def rexnetr_200(pretrained=False, **kwargs): + """ReXNet V1 2.0x w/ rounded (mod 8) channels""" + return _create_rexnet('rexnetr_200', pretrained, width_mult=2.0, ch_div=8, **kwargs)