|
|
@ -109,7 +109,7 @@ def test_model_default_cfgs(model_name, batch_size):
|
|
|
|
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
|
|
|
|
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
|
|
|
|
outputs = model.forward(input_tensor)
|
|
|
|
outputs = model.forward(input_tensor)
|
|
|
|
assert len(outputs.shape) == 4
|
|
|
|
assert len(outputs.shape) == 4
|
|
|
|
if not isinstance(model, timm.models.GhostNet) or not isinstance(model, timm.models.GhostNet):
|
|
|
|
if not isinstance(model, timm.models.MobileNetV3) or not isinstance(model, timm.models.GhostNet):
|
|
|
|
# FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
|
|
|
|
# FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
|
|
|
|
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
|
|
|
|
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
|
|
|
|
|
|
|
|
|
|
|
|