Fix GhostNet bug

pull/548/head
iamhankai 4 years ago
parent 9295f0e85e
commit b0bd2884e7

@ -226,7 +226,7 @@ class GhostNet(nn.Module):
def get_classifier(self):
return self.classifier
def forward(self, x):
def forward_features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
@ -237,6 +237,10 @@ class GhostNet(nn.Module):
x = x.view(x.size(0), -1)
if self.dropout > 0.:
x = F.dropout(x, p=self.dropout, training=self.training)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.classifier(x)
return x
@ -276,7 +280,6 @@ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
width=width,
**kwargs,
)
print(model_kwargs)
return build_model_with_cfg(
GhostNet, variant, pretrained,
default_cfg=default_cfgs[variant],

Loading…
Cancel
Save