From 2df77ee5cbc41851566cc63a28acf8623ff04e91 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 15 Apr 2021 10:20:26 -0700 Subject: [PATCH] Fix torchscript compat and features_only behaviour in GhostNet PR. A few minor formatting changes. Reuse existing layers. --- timm/models/ghostnet.py | 135 +++++++++++++--------------------------- 1 file changed, 44 insertions(+), 91 deletions(-) diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index ffb168b2..76761d1c 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -4,13 +4,17 @@ GhostNet: More Features from Cheap Operations. https://arxiv.org/abs/1911.11907 The train script of the model is similar to that of MobileNetV3 Original model: https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch """ +import math +from functools import partial + import torch import torch.nn as nn import torch.nn.functional as F -import math + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .layers import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid +from .efficientnet_blocks import SqueezeExcite, ConvBnAct, make_divisible from .helpers import build_model_with_cfg from .registry import register_model @@ -36,62 +40,7 @@ default_cfgs = { } -def _make_divisible(v, divisor, min_value=None): - """ - This function is taken from the original tf repo. - It ensures that all layers have a channel number that is divisible by 8 - It can be seen here: - https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py - """ - if min_value is None: - min_value = divisor - new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) - # Make sure that round down does not go down by more than 10%. - if new_v < 0.9 * v: - new_v += divisor - return new_v - - -def hard_sigmoid(x, inplace: bool = False): - if inplace: - return x.add_(3.).clamp_(0., 6.).div_(6.) - else: - return F.relu6(x + 3.) / 6. - - -class SqueezeExcite(nn.Module): - def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, - act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_): - super(SqueezeExcite, self).__init__() - self.gate_fn = gate_fn - reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) - self.act1 = act_layer(inplace=True) - self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) - - def forward(self, x): - x_se = self.avg_pool(x) - x_se = self.conv_reduce(x_se) - x_se = self.act1(x_se) - x_se = self.conv_expand(x_se) - x = x * self.gate_fn(x_se) - return x - - -class ConvBnAct(nn.Module): - def __init__(self, in_chs, out_chs, kernel_size, - stride=1, act_layer=nn.ReLU): - super(ConvBnAct, self).__init__() - self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False) - self.bn1 = nn.BatchNorm2d(out_chs) - self.act1 = act_layer(inplace=True) - - def forward(self, x): - x = self.conv(x) - x = self.bn1(x) - x = self.act1(x) - return x +_SE_LAYER = partial(SqueezeExcite, gate_fn=hard_sigmoid, divisor=4) class GhostModule(nn.Module): @@ -99,7 +48,7 @@ class GhostModule(nn.Module): super(GhostModule, self).__init__() self.oup = oup init_channels = math.ceil(oup / ratio) - new_channels = init_channels*(ratio-1) + new_channels = init_channels * (ratio - 1) self.primary_conv = nn.Sequential( nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), @@ -116,8 +65,8 @@ class GhostModule(nn.Module): def forward(self, x): x1 = self.primary_conv(x) x2 = self.cheap_operation(x1) - out = torch.cat([x1,x2], dim=1) - return out[:,:self.oup,:,:] + out = torch.cat([x1, x2], dim=1) + return out[:, :self.oup, :, :] class GhostBottleneck(nn.Module): @@ -134,27 +83,28 @@ class GhostBottleneck(nn.Module): # Depth-wise convolution if self.stride > 1: - self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride=stride, - padding=(dw_kernel_size-1)//2, - groups=mid_chs, bias=False) + self.conv_dw = nn.Conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=mid_chs, bias=False) self.bn_dw = nn.BatchNorm2d(mid_chs) + else: + self.conv_dw = None + self.bn_dw = None # Squeeze-and-excitation - if has_se: - self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) - else: - self.se = None + self.se = _SE_LAYER(mid_chs, se_ratio=se_ratio) if has_se else None # Point-wise linear projection self.ghost2 = GhostModule(mid_chs, out_chs, relu=False) # shortcut - if (in_chs == out_chs and self.stride == 1): + if in_chs == out_chs and self.stride == 1: self.shortcut = nn.Sequential() else: self.shortcut = nn.Sequential( - nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride, - padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), + nn.Conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), nn.BatchNorm2d(in_chs), nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), nn.BatchNorm2d(out_chs), @@ -168,7 +118,7 @@ class GhostBottleneck(nn.Module): x = self.ghost1(x) # Depth-wise convolution - if self.stride > 1: + if self.conv_dw is not None: x = self.conv_dw(x) x = self.bn_dw(x) @@ -184,52 +134,55 @@ class GhostBottleneck(nn.Module): class GhostNet(nn.Module): - def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3): + def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32): super(GhostNet, self).__init__() # setting of inverted residual blocks + assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported' self.cfgs = cfgs self.num_classes = num_classes self.dropout = dropout self.feature_info = [] # building first layer - output_channel = _make_divisible(16 * width, 4) - self.conv_stem = nn.Conv2d(in_chans, output_channel, 3, 2, 1, bias=False) - self.feature_info.append(dict(num_chs=output_channel, reduction=2, module=f'conv_stem')) - self.bn1 = nn.BatchNorm2d(output_channel) + stem_chs = make_divisible(16 * width, 4) + self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False) + self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=f'conv_stem')) + self.bn1 = nn.BatchNorm2d(stem_chs) self.act1 = nn.ReLU(inplace=True) - input_channel = output_channel + prev_chs = stem_chs # building inverted residual blocks stages = nn.ModuleList([]) block = GhostBottleneck stage_idx = 0 + net_stride = 2 for cfg in self.cfgs: layers = [] + s = 1 for k, exp_size, c, se_ratio, s in cfg: - output_channel = _make_divisible(c * width, 4) - hidden_channel = _make_divisible(exp_size * width, 4) - layers.append(block(input_channel, hidden_channel, output_channel, k, s, - se_ratio=se_ratio)) - input_channel = output_channel + out_chs = make_divisible(c * width, 4) + mid_chs = make_divisible(exp_size * width, 4) + layers.append(block(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio)) + prev_chs = out_chs if s > 1: - self.feature_info.append(dict(num_chs=output_channel, reduction=2**(stage_idx+2), - module=f'blocks.{stage_idx}')) + net_stride *= 2 + self.feature_info.append(dict( + num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}')) stages.append(nn.Sequential(*layers)) stage_idx += 1 - output_channel = _make_divisible(exp_size * width, 4) - stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1))) - self.pool_dim = input_channel = output_channel + out_chs = make_divisible(exp_size * width, 4) + stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1))) + self.pool_dim = prev_chs = out_chs self.blocks = nn.Sequential(*stages) # building last several layers - self.num_features = output_channel = 1280 + self.num_features = out_chs = 1280 self.global_pool = SelectAdaptivePool2d(pool_type='avg') - self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True) + self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True) self.act2 = nn.ReLU(inplace=True) - self.classifier = nn.Linear(output_channel, num_classes) + self.classifier = Linear(out_chs, num_classes) def get_classifier(self): return self.classifier