From b2094f4ee845d89aca8de65ae9b6ae09829a8b8e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 3 Oct 2021 17:31:22 -0700 Subject: [PATCH] support bits checkpoints in avg/load --- avg_checkpoints.py | 4 ++++ timm/models/helpers.py | 15 +++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/avg_checkpoints.py b/avg_checkpoints.py index 1f7604b0..ea8bbe84 100755 --- a/avg_checkpoints.py +++ b/avg_checkpoints.py @@ -41,6 +41,10 @@ def checkpoint_metric(checkpoint_path): metric = None if 'metric' in checkpoint: metric = checkpoint['metric'] + elif 'metrics' in checkpoint and 'metric_name' in checkpoint: + metrics = checkpoint['metrics'] + print(metrics) + metric = metrics[checkpoint['metric_name']] return metric diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 662a7a48..bd97cf20 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -24,13 +24,20 @@ _logger = logging.getLogger(__name__) def load_state_dict(checkpoint_path, use_ema=False): if checkpoint_path and os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') - state_dict_key = 'state_dict' + state_dict_key = '' if isinstance(checkpoint, dict): - if use_ema and 'state_dict_ema' in checkpoint: + if use_ema and checkpoint.get('state_dict_ema', None) is not None: state_dict_key = 'state_dict_ema' - if state_dict_key and state_dict_key in checkpoint: + elif use_ema and checkpoint.get('model_ema', None) is not None: + state_dict_key = 'model_ema' + elif 'state_dict' in checkpoint: + state_dict_key = 'state_dict' + elif 'model' in checkpoint: + state_dict_key = 'model' + if state_dict_key: + state_dict = checkpoint[state_dict_key] new_state_dict = OrderedDict() - for k, v in checkpoint[state_dict_key].items(): + for k, v in state_dict.items(): # strip `module.` prefix name = k[7:] if k.startswith('module') else k new_state_dict[name] = v