Fix load_state_dict to handle None ema entries

pull/1239/head
Ross Wightman 3 years ago
parent 0e212e8fe5
commit 820ae9925e

@ -27,9 +27,9 @@ def load_state_dict(checkpoint_path, use_ema=False):
checkpoint = torch.load(checkpoint_path, map_location='cpu') checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict_key = '' state_dict_key = ''
if isinstance(checkpoint, dict): if isinstance(checkpoint, dict):
if use_ema and 'state_dict_ema' in checkpoint: if use_ema and checkpoint.get('state_dict_ema', None) is not None:
state_dict_key = 'state_dict_ema' state_dict_key = 'state_dict_ema'
elif use_ema and 'model_ema' in checkpoint: elif use_ema and checkpoint.get('model_ema', None) is not None:
state_dict_key = 'model_ema' state_dict_key = 'model_ema'
elif 'state_dict' in checkpoint: elif 'state_dict' in checkpoint:
state_dict_key = 'state_dict' state_dict_key = 'state_dict'

Loading…
Cancel
Save