Replace `getattr` with `hasattr` in `get_state_dict`

Avoid throwing an `AttributeError: '...' object has no attribute 'module'` on non-(Distributed)DataParallel modules.
pull/13/head
Maxim Berman 6 years ago committed by GitHub
parent 1d7f2d93a6
commit ebafb4238e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,7 +18,7 @@ def get_state_dict(model):
if isinstance(model, ModelEma): if isinstance(model, ModelEma):
return get_state_dict(model.ema) return get_state_dict(model.ema)
else: else:
return model.module.state_dict() if getattr(model, 'module') else model.state_dict() return model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
class CheckpointSaver: class CheckpointSaver:

Loading…
Cancel
Save