diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 91466752..67e4c343 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -226,7 +226,7 @@ class GhostNet(nn.Module): def get_classifier(self): return self.classifier - def forward(self, x): + def forward_features(self, x): x = self.conv_stem(x) x = self.bn1(x) x = self.act1(x) @@ -237,6 +237,10 @@ class GhostNet(nn.Module): x = x.view(x.size(0), -1) if self.dropout > 0.: x = F.dropout(x, p=self.dropout, training=self.training) + return x + + def forward(self, x): + x = self.forward_features(x) x = self.classifier(x) return x @@ -276,7 +280,6 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs): width=width, **kwargs, ) - print(model_kwargs) return build_model_with_cfg( GhostNet, variant, pretrained, default_cfg=default_cfgs[variant],