diff --git a/avg_checkpoints.py b/avg_checkpoints.py index 1f7604b0..ea8bbe84 100755 --- a/avg_checkpoints.py +++ b/avg_checkpoints.py @@ -41,6 +41,10 @@ def checkpoint_metric(checkpoint_path): metric = None if 'metric' in checkpoint: metric = checkpoint['metric'] + elif 'metrics' in checkpoint and 'metric_name' in checkpoint: + metrics = checkpoint['metrics'] + print(metrics) + metric = metrics[checkpoint['metric_name']] return metric