From 5e95ced5a7763541f7219f35fd155e3fbfe66e8b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 23 Jun 2021 11:10:05 -0700 Subject: [PATCH] timm bits checkpoint support for avg_checkpoints.py --- avg_checkpoints.py | 4 ++++ timm/models/helpers.py | 13 ++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/avg_checkpoints.py b/avg_checkpoints.py index 1f7604b0..ea8bbe84 100755 --- a/avg_checkpoints.py +++ b/avg_checkpoints.py @@ -41,6 +41,10 @@ def checkpoint_metric(checkpoint_path): metric = None if 'metric' in checkpoint: metric = checkpoint['metric'] + elif 'metrics' in checkpoint and 'metric_name' in checkpoint: + metrics = checkpoint['metrics'] + print(metrics) + metric = metrics[checkpoint['metric_name']] return metric diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 662a7a48..9dafeefa 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -24,13 +24,20 @@ _logger = logging.getLogger(__name__) def load_state_dict(checkpoint_path, use_ema=False): if checkpoint_path and os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') - state_dict_key = 'state_dict' + state_dict_key = '' if isinstance(checkpoint, dict): if use_ema and 'state_dict_ema' in checkpoint: state_dict_key = 'state_dict_ema' - if state_dict_key and state_dict_key in checkpoint: + elif use_ema and 'model_ema' in checkpoint: + state_dict_key = 'model_ema' + elif 'state_dict' in checkpoint: + state_dict_key = 'state_dict' + elif 'model' in checkpoint: + state_dict_key = 'model' + if state_dict_key: + state_dict = checkpoint[state_dict_key] new_state_dict = OrderedDict() - for k, v in checkpoint[state_dict_key].items(): + for k, v in state_dict.items(): # strip `module.` prefix name = k[7:] if k.startswith('module') else k new_state_dict[name] = v