Update ghostnet

pull/548/head
iamhankai 4 years ago
parent b0bd2884e7
commit 7236ba08e2

tests/test_models.py

@@ -109,8 +109,8 @@ def test_model_default_cfgs(model_name, batch_size):
     model.reset_classifier(0, '')  # reset classifier and set global pooling to pass-through
     outputs = model.forward(input_tensor)
     assert len(outputs.shape) == 4
-    if not isinstance(model, timm.models.MobileNetV3):
-        # FIXME mobilenetv3 forward_features vs removed pooling differ
+    if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet)):
+        # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
     # check classifier name matches default_cfg
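GhostNet, like MobileNetV3, pools and applies conv_head inside forward_features, so the spatial-size check is skipped for it. A minimal sketch of the path the test now tolerates (assumes a timm build with this patch):

import torch
import timm

model = timm.create_model('ghostnet_100')
model.reset_classifier(0, '')  # head off, pooling pass-through
out = model(torch.randn(2, 3, 224, 224))
assert out.ndim == 4  # still a 4-D map, but its spatial size need not match default_cfg pool_size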
@@ -143,7 +143,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
 EXCLUDE_JIT_FILTERS = [
     '*iabn*', 'tresnet*',  # models using inplace abn unlikely to ever be scriptable
-    'dla*', 'hrnet*',  # hopefully fix at some point
+    'dla*', 'hrnet*', 'ghostnet*',  # hopefully fix at some point
 ]
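These exclusion entries are fnmatch-style globs, so the single 'ghostnet*' pattern covers all three variants. A quick stdlib illustration (the helper is just for demonstration):

from fnmatch import fnmatch

filters = ['*iabn*', 'tresnet*', 'dla*', 'hrnet*', 'ghostnet*']

def is_excluded(name):
    return any(fnmatch(name, pat) for pat in filters)

assert is_excluded('ghostnet_130')
assert not is_excluded('resnet50')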

timm/models/ghostnet.py

@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import math
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .layers import SelectAdaptivePool2d, Linear
 from .helpers import build_model_with_cfg
 from .registry import register_model
@@ -30,7 +31,7 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
     'ghostnet_050': _cfg(url=''),
     'ghostnet_100': _cfg(
-        url='https://github.com/huawei-noah/CV-backbones/blob/master/ghostnet_pytorch/models/state_dict_73.98.pth'),
+        url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'),
     'ghostnet_130': _cfg(url=''),
 }
@@ -189,17 +190,21 @@ class GhostNet(nn.Module):
         self.cfgs = cfgs
         self.num_classes = num_classes
         self.dropout = dropout
+        self.feature_info = []

         # building first layer
         output_channel = _make_divisible(16 * width, 4)
         self.conv_stem = nn.Conv2d(in_chans, output_channel, 3, 2, 1, bias=False)
+        self.feature_info.append(dict(num_chs=output_channel, reduction=2, module='conv_stem'))
         self.bn1 = nn.BatchNorm2d(output_channel)
         self.act1 = nn.ReLU(inplace=True)
         input_channel = output_channel

         # building inverted residual blocks
-        stages = []
+        stages = nn.ModuleList([])
         block = GhostBottleneck
+        stage_idx = 0
+        net_stride = 2  # running network stride; the stem has already downsampled by 2
         for cfg in self.cfgs:
             layers = []
             for k, exp_size, c, se_ratio, s in cfg:
@@ -208,17 +213,22 @@ class GhostNet(nn.Module):
                 layers.append(block(input_channel, hidden_channel, output_channel, k, s,
                                     se_ratio=se_ratio))
                 input_channel = output_channel
+            if s > 1:
+                net_stride *= 2
+                self.feature_info.append(dict(
+                    num_chs=output_channel, reduction=net_stride, module=f'blocks.{stage_idx}'))
             stages.append(nn.Sequential(*layers))
+            stage_idx += 1

         output_channel = _make_divisible(exp_size * width, 4)
         stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1)))
-        input_channel = output_channel
+        self.pool_dim = input_channel = output_channel

         self.blocks = nn.Sequential(*stages)

         # building last several layers
-        output_channel = 1280
-        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.num_features = output_channel = 1280
+        self.global_pool = SelectAdaptivePool2d(pool_type='avg')
         self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True)
         self.act2 = nn.ReLU(inplace=True)
         self.classifier = nn.Linear(output_channel, num_classes)
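With the stem entry plus one entry per stride-2 stage, feature_info now describes five feature levels. A sketch of what to expect (channel numbers assume width=1.0 and the standard GhostNet config):

import timm

model = timm.create_model('ghostnet_100')
for info in model.feature_info:
    print(info)
# expected: {'num_chs': 16, 'reduction': 2, 'module': 'conv_stem'}, then
# entries at reduction 4, 8, 16 and 32 from the stride-2 stages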
@@ -226,6 +236,12 @@ class GhostNet(nn.Module):
     def get_classifier(self):
         return self.classifier

+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        # cannot meaningfully change pooling of efficient head after creation
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
     def forward_features(self, x):
         x = self.conv_stem(x)
         x = self.bn1(x)
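The new Linear head is sized from self.num_features (1280) because the classifier consumes the conv_head output, not the pre-head features. Typical downstream use, e.g. for fine-tuning (sketch):

import timm

model = timm.create_model('ghostnet_100', pretrained=True)
model.reset_classifier(num_classes=10)  # fresh Linear head on the 1280-dim features
# or remove the head entirely for embedding extraction:
model.reset_classifier(0)               # classifier becomes nn.Identity()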
@@ -234,13 +250,14 @@ class GhostNet(nn.Module):
         x = self.global_pool(x)
         x = self.conv_head(x)
         x = self.act2(x)
-        x = x.view(x.size(0), -1)
-        if self.dropout > 0.:
-            x = F.dropout(x, p=self.dropout, training=self.training)
         return x

     def forward(self, x):
         x = self.forward_features(x)
+        if not self.global_pool.is_identity():
+            x = x.view(x.size(0), -1)
+        if self.dropout > 0.:
+            x = F.dropout(x, p=self.dropout, training=self.training)
         x = self.classifier(x)
         return x
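With flatten and dropout moved out of forward_features, the two entry points split cleanly: forward_features ends at the pooled conv_head activations, forward adds flatten, dropout and the classifier. A sketch of the resulting shapes (default pooling, width=1.0):

import torch
import timm

model = timm.create_model('ghostnet_100').eval()
x = torch.randn(1, 3, 224, 224)
print(model.forward_features(x).shape)  # expected: torch.Size([1, 1280, 1, 1])
print(model(x).shape)                   # expected: torch.Size([1, 1000])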
@@ -283,6 +300,7 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
     return build_model_with_cfg(
         GhostNet, variant, pretrained,
         default_cfg=default_cfgs[variant],
+        feature_cfg=dict(flatten_sequential=True),
         **model_kwargs)
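flatten_sequential lets timm's feature wrapper treat each nn.Sequential stage inside self.blocks as a hookable module, which is what makes features_only mode work here. Intended usage (sketch):

import torch
import timm

model = timm.create_model('ghostnet_100', features_only=True)
feats = model(torch.randn(1, 3, 224, 224))
print(model.feature_info.channels())  # channels per returned feature map
print([f.shape[-1] for f in feats])   # expected spatial sizes: 112, 56, 28, 14, 7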
@@ -301,8 +319,7 @@ def ghostnet_100(pretrained=False, **kwargs):
 @register_model
-def ghostnet_100(pretrained=False, **kwargs):
+def ghostnet_130(pretrained=False, **kwargs):
     """ GhostNet-1.3x """
     model = _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs)
     return model
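With the duplicate function name fixed, all three variants now resolve through the registry (sketch):

import timm

print(timm.list_models('ghostnet*'))
# expected: ['ghostnet_050', 'ghostnet_100', 'ghostnet_130']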
