diff --git a/validate.py b/validate.py index f782910a..29ef7d5e 100755 --- a/validate.py +++ b/validate.py @@ -218,22 +218,17 @@ def validate(args): loss=losses, top1=top1, top5=top5)) if real_labels is not None: - real_top1 = real_labels.get_accuracy(k=1) - real_top5 = real_labels.get_accuracy(k=5) - results = OrderedDict( - top1=round(real_top1, 4), top1_err=round(100 - real_top1, 4), - top5=round(real_top5, 4), top5_err=round(100 - real_top5, 4), - top1_original=round(top1.avg, 4), - top5_original=round(top5.avg, 4)) + # real labels mode replaces topk values at the end + top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5) else: - results = OrderedDict( - top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4), - top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4)) - results.update(OrderedDict( + top1a, top5a = top1.avg, top5.avg + results = OrderedDict( + top1=round(top1a, 4), top1_err=round(100 - top1a, 4), + top5=round(top5a, 4), top5_err=round(100 - top5a, 4), param_count=round(param_count / 1e6, 2), img_size=data_config['input_size'][-1], cropt_pct=crop_pct, - interpolation=data_config['interpolation'])) + interpolation=data_config['interpolation']) _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( results['top1'], results['top1_err'], results['top5'], results['top5_err']))