From b0bd2884e7d1c7ea012230ba4959f709c65d019d Mon Sep 17 00:00:00 2001 From: iamhankai Date: Thu, 8 Apr 2021 08:24:18 +0800 Subject: [PATCH] Fix GhostNet bug --- timm/models/ghostnet.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 91466752..67e4c343 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -226,7 +226,7 @@ class GhostNet(nn.Module): def get_classifier(self): return self.classifier - def forward(self, x): + def forward_features(self, x): x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x) @@ -237,6 +237,10 @@ class GhostNet(nn.Module): x = x.view(x.size(0), -1) if self.dropout > 0.: x = F.dropout(x, p=self.dropout, training=self.training) + return x + + def forward(self, x): + x = self.forward_features(x) x = self.classifier(x) return x @@ -276,7 +280,6 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs): width=width, **kwargs, ) - print(model_kwargs) return build_model_with_cfg( GhostNet, variant, pretrained, default_cfg=default_cfgs[variant],