|
|
|
@ -41,6 +41,10 @@ def checkpoint_metric(checkpoint_path):
|
|
|
|
|
metric = None
|
|
|
|
|
if 'metric' in checkpoint:
|
|
|
|
|
metric = checkpoint['metric']
|
|
|
|
|
elif 'metrics' in checkpoint and 'metric_name' in checkpoint:
|
|
|
|
|
metrics = checkpoint['metrics']
|
|
|
|
|
print(metrics)
|
|
|
|
|
metric = metrics[checkpoint['metric_name']]
|
|
|
|
|
return metric
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|