Update test_models.py

pull/548/head
Kai Han 4 years ago committed by GitHub
parent b85650093f
commit 4b2952825e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -109,7 +109,7 @@ def test_model_default_cfgs(model_name, batch_size):
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
outputs = model.forward(input_tensor) outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4 assert len(outputs.shape) == 4
if not isinstance(model, timm.models.MobileNetV3) or not isinstance(model, timm.models.GhostNet): if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet):
# FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]

Loading…
Cancel
Save