@@ -10,6 +10,7 @@ import torch.nn.functional as F
 import math

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .layers import SelectAdaptivePool2d, Linear
 from .helpers import build_model_with_cfg
 from .registry import register_model

@@ -30,7 +31,7 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
     'ghostnet_050': _cfg(url=''),
     'ghostnet_100': _cfg(
-        url='https://github.com/huawei-noah/CV-backbones/blob/master/ghostnet_pytorch/models/state_dict_73.98.pth'),
+        url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'),
     'ghostnet_130': _cfg(url=''),
 }

@@ -189,17 +190,20 @@ class GhostNet(nn.Module):
         self.cfgs = cfgs
         self.num_classes = num_classes
         self.dropout = dropout
+        self.feature_info = []

         # building first layer
         output_channel = _make_divisible(16 * width, 4)
         self.conv_stem = nn.Conv2d(in_chans, output_channel, 3, 2, 1, bias=False)
+        self.feature_info.append(dict(num_chs=output_channel, reduction=2, module=f'conv_stem'))
         self.bn1 = nn.BatchNorm2d(output_channel)
         self.act1 = nn.ReLU(inplace=True)
         input_channel = output_channel

         # building inverted residual blocks
-        stages = []
+        stages = nn.ModuleList([])
         block = GhostBottleneck
+        stage_idx = 0
         for cfg in self.cfgs:
             layers = []
             for k, exp_size, c, se_ratio, s in cfg:
@@ -208,17 +212,21 @@ class GhostNet(nn.Module):
                 layers.append(block(input_channel, hidden_channel, output_channel, k, s,
                                     se_ratio=se_ratio))
                 input_channel = output_channel
+            if s > 1:
+                self.feature_info.append(dict(num_chs=output_channel, reduction=2**(stage_idx+2),
+                                              module=f'blocks.{stage_idx}'))
             stages.append(nn.Sequential(*layers))
+            stage_idx += 1

         output_channel = _make_divisible(exp_size * width, 4)
         stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1)))
-        input_channel = output_channel
+        self.pool_dim = input_channel = output_channel

         self.blocks = nn.Sequential(*stages)

         # building last several layers
-        output_channel = 1280
-        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.num_features = output_channel = 1280
+        self.global_pool = SelectAdaptivePool2d(pool_type='avg')
         self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True)
         self.act2 = nn.ReLU(inplace=True)
         self.classifier = nn.Linear(output_channel, num_classes)
@@ -226,6 +234,12 @@ class GhostNet(nn.Module):
     def get_classifier(self):
         return self.classifier

+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        # cannot meaningfully change pooling of efficient head after creation
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
     def forward_features(self, x):
         x = self.conv_stem(x)
         x = self.bn1(x)
@@ -234,13 +248,14 @@ class GhostNet(nn.Module):
         x = self.global_pool(x)
         x = self.conv_head(x)
         x = self.act2(x)
-        x = x.view(x.size(0), -1)
-        if self.dropout > 0.:
-            x = F.dropout(x, p=self.dropout, training=self.training)
         return x

     def forward(self, x):
         x = self.forward_features(x)
+        if not self.global_pool.is_identity():
+            x = x.view(x.size(0), -1)
+        if self.dropout > 0.:
+            x = F.dropout(x, p=self.dropout, training=self.training)
         x = self.classifier(x)
         return x

@@ -283,6 +298,7 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
     return build_model_with_cfg(
         GhostNet, variant, pretrained,
         default_cfg=default_cfgs[variant],
+        feature_cfg=dict(flatten_sequential=True),
         **model_kwargs)

@@ -301,8 +317,7 @@ def ghostnet_100(pretrained=False, **kwargs):


-
 @register_model
-def ghostnet_100(pretrained=False, **kwargs):
+def ghostnet_130(pretrained=False, **kwargs):
     """ GhostNet-1.3x """
     model = _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs)
     return model
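
Usage sketch (not part of the patch): assuming this file is timm's models/ghostnet.py inside an installed timm checkout, the feature_info entries plus feature_cfg=dict(flatten_sequential=True) added above are what let the create_model factory build a features-only backbone from this model:

    import torch
    import timm

    # Classification model: the head behaves as before this patch.
    model = timm.create_model('ghostnet_100', pretrained=False)
    logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])

    # Features-only backbone: self.feature_info records one entry per
    # stride-2 stage, and flatten_sequential=True lets the feature wrapper
    # unroll the nested nn.Sequential blocks into hookable children.
    backbone = timm.create_model('ghostnet_100', features_only=True, pretrained=False)
    for fmap in backbone(torch.randn(1, 3, 224, 224)):
        print(fmap.shape)  # one map per stage: strides 2, 4, 8, 16, 32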