Update ghostnet

pull/548/head
iamhankai 4 years ago
parent 7236ba08e2
commit b85650093f

@ -109,7 +109,7 @@ def test_model_default_cfgs(model_name, batch_size):
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, timm.models.GhostNet) or not isinstance(model, timm.models.GhostNet):
if not isinstance(model, timm.models.MobileNetV3) or not isinstance(model, timm.models.GhostNet):
# FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]

Loading…
Cancel
Save