From 7236ba08e2c4a3ff75026e1bbea76548f925095d Mon Sep 17 00:00:00 2001
From: iamhankai
Date: Sun, 11 Apr 2021 20:32:19 +0800
Subject: [PATCH] Update ghostnet

---
Review notes (everything between "---" and the first "diff --git" is ignored by `git am`):
- tests/test_models.py: the pooled-shape assertion is now skipped for BOTH
  MobileNetV3 and GhostNet. The submitted condition tested GhostNet twice
  (`not A or not A`) and silently dropped the MobileNetV3 exclusion.
- timm/models/ghostnet.py, reset_classifier(): the new classifier is built with
  `nn.Linear(self.num_features, ...)`. The bare name `Linear` is not imported by
  this patch, and the classifier input is the conv_head output (num_features),
  not the pre-head width stored in pool_dim.
- timm/models/ghostnet.py: dropped the placeholder-less f-string prefix on
  'conv_stem'.
- All fixes are same-line substitutions, so hunk headers and the diffstat are
  unchanged.
- NOTE(review): this patch was recovered from a whitespace-collapsed copy.
  Indentation and blank context lines were reconstructed from Python syntax and
  hunk line counts; in particular confirm the nesting level of the `if s > 1:`
  block and the indentation of continuation lines against the target tree.

 tests/test_models.py    |  6 +++---
 timm/models/ghostnet.py | 35 +++++++++++++++++++++++++----------
 2 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 4fbdc85b..d1df7868 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -109,8 +109,8 @@ def test_model_default_cfgs(model_name, batch_size):
         model.reset_classifier(0, '')  # reset classifier and set global pooling to pass-through
         outputs = model.forward(input_tensor)
         assert len(outputs.shape) == 4
-        if not isinstance(model, timm.models.MobileNetV3):
-            # FIXME mobilenetv3 forward_features vs removed pooling differ
+        if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet)):
+            # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
             assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
 
     # check classifier name matches default_cfg
@@ -143,7 +143,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
 
 EXCLUDE_JIT_FILTERS = [
     '*iabn*', 'tresnet*',  # models using inplace abn unlikely to ever be scriptable
-    'dla*', 'hrnet*',  # hopefully fix at some point
+    'dla*', 'hrnet*', 'ghostnet*',  # hopefully fix at some point
 ]
 
 
diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py
index 67e4c343..54c50820 100644
--- a/timm/models/ghostnet.py
+++ b/timm/models/ghostnet.py
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import math
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .layers import SelectAdaptivePool2d
 from .helpers import build_model_with_cfg
 from .registry import register_model
 
@@ -30,7 +31,7 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
     'ghostnet_050': _cfg(url=''),
     'ghostnet_100': _cfg(
-        url='https://github.com/huawei-noah/CV-backbones/blob/master/ghostnet_pytorch/models/state_dict_73.98.pth'),
+        url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'),
     'ghostnet_130': _cfg(url=''),
 }
 
@@ -189,17 +190,20 @@ class GhostNet(nn.Module):
         self.cfgs = cfgs
         self.num_classes = num_classes
         self.dropout = dropout
+        self.feature_info = []
 
         # building first layer
         output_channel = _make_divisible(16 * width, 4)
         self.conv_stem = nn.Conv2d(in_chans, output_channel, 3, 2, 1, bias=False)
+        self.feature_info.append(dict(num_chs=output_channel, reduction=2, module='conv_stem'))
         self.bn1 = nn.BatchNorm2d(output_channel)
         self.act1 = nn.ReLU(inplace=True)
         input_channel = output_channel
 
         # building inverted residual blocks
-        stages = []
+        stages = nn.ModuleList([])
         block = GhostBottleneck
+        stage_idx = 0
         for cfg in self.cfgs:
             layers = []
             for k, exp_size, c, se_ratio, s in cfg:
@@ -208,17 +212,21 @@ class GhostNet(nn.Module):
                 layers.append(block(input_channel, hidden_channel, output_channel, k, s,
                               se_ratio=se_ratio))
                 input_channel = output_channel
+                if s > 1:
+                    self.feature_info.append(dict(num_chs=output_channel, reduction=2**(stage_idx+2),
+                        module=f'blocks.{stage_idx}'))
             stages.append(nn.Sequential(*layers))
+            stage_idx += 1
 
         output_channel = _make_divisible(exp_size * width, 4)
         stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1)))
-        input_channel = output_channel
+        self.pool_dim = input_channel = output_channel
 
         self.blocks = nn.Sequential(*stages)
 
         # building last several layers
-        output_channel = 1280
-        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.num_features = output_channel = 1280
+        self.global_pool = SelectAdaptivePool2d(pool_type='avg')
         self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True)
         self.act2 = nn.ReLU(inplace=True)
         self.classifier = nn.Linear(output_channel, num_classes)
@@ -226,6 +234,12 @@ class GhostNet(nn.Module):
     def get_classifier(self):
         return self.classifier
 
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        # cannot meaningfully change pooling of efficient head after creation
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
     def forward_features(self, x):
         x = self.conv_stem(x)
         x = self.bn1(x)
@@ -234,13 +248,14 @@ class GhostNet(nn.Module):
         x = self.global_pool(x)
         x = self.conv_head(x)
         x = self.act2(x)
-        x = x.view(x.size(0), -1)
-        if self.dropout > 0.:
-            x = F.dropout(x, p=self.dropout, training=self.training)
         return x
 
     def forward(self, x):
         x = self.forward_features(x)
+        if not self.global_pool.is_identity():
+            x = x.view(x.size(0), -1)
+        if self.dropout > 0.:
+            x = F.dropout(x, p=self.dropout, training=self.training)
         x = self.classifier(x)
         return x
 
@@ -283,6 +298,7 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
     return build_model_with_cfg(
         GhostNet, variant, pretrained,
         default_cfg=default_cfgs[variant],
+        feature_cfg=dict(flatten_sequential=True),
         **model_kwargs)
 
 
@@ -301,8 +317,7 @@ def ghostnet_100(pretrained=False, **kwargs):
 
 
 @register_model
-def ghostnet_100(pretrained=False, **kwargs):
+def ghostnet_130(pretrained=False, **kwargs):
     """ GhostNet-1.3x """
     model = _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs)
     return model
-