From 820ae9925e66357bbd5b5d7db4e16ec7809c3e62 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 3 Dec 2021 13:22:25 -0800 Subject: [PATCH] Fix load_state_dict to handle None ema entries --- timm/models/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 6aa1faa3..16ce64d0 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -27,9 +27,9 @@ def load_state_dict(checkpoint_path, use_ema=False): checkpoint = torch.load(checkpoint_path, map_location='cpu') state_dict_key = '' if isinstance(checkpoint, dict): - if use_ema and 'state_dict_ema' in checkpoint: + if use_ema and checkpoint.get('state_dict_ema', None) is not None: state_dict_key = 'state_dict_ema' - elif use_ema and 'model_ema' in checkpoint: + elif use_ema and checkpoint.get('model_ema', None) is not None: state_dict_key = 'model_ema' elif 'state_dict' in checkpoint: state_dict_key = 'state_dict'