diff --git a/tests/test_models.py b/tests/test_models.py index d1df7868..7b30d364 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -109,7 +109,7 @@ def test_model_default_cfgs(model_name, batch_size): model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through outputs = model.forward(input_tensor) assert len(outputs.shape) == 4 - if not isinstance(model, timm.models.GhostNet) or not isinstance(model, timm.models.GhostNet): + if not isinstance(model, timm.models.MobileNetV3) or not isinstance(model, timm.models.GhostNet): # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]