diff --git a/timm/models/factory.py b/timm/models/factory.py index a7b6c90e..736b5d1d 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -55,6 +55,7 @@ def create_model( raise RuntimeError('Unknown model (%s)' % model_name) if checkpoint_path: - load_checkpoint(model, checkpoint_path) + features_only = kwargs.get("features_only", False) + load_checkpoint(model, checkpoint_path, strict=not features_only) return model