From 13cf68850bd993d8b763941a965df7317cd63af2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 10 Apr 2020 14:41:08 -0700
Subject: [PATCH] Remove poorly named metrics carried over from the torch
 imagenet example. Use top1/top5 in CSV output for consistency with existing
 validation results files, and acc elsewhere. Fixes #111

---
 timm/utils.py |  3 +--
 train.py      | 29 ++++++++++++++---------------
 validate.py   | 12 ++++++------
 3 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/timm/utils.py b/timm/utils.py
index 8957d564..2cae024d 100644
--- a/timm/utils.py
+++ b/timm/utils.py
@@ -170,10 +170,9 @@ class AverageMeter:
 
 
 def accuracy(output, target, topk=(1,)):
-    """Computes the precision@k for the specified values of k"""
+    """Computes the accuracy over the k top predictions for the specified values of k"""
     maxk = max(topk)
     batch_size = target.size(0)
-
     _, pred = output.topk(maxk, 1, True, True)
     pred = pred.t()
     correct = pred.eq(target.view(1, -1).expand_as(pred))
diff --git a/train.py b/train.py
index a6fd0e47..e88640a7 100755
--- a/train.py
+++ b/train.py
@@ -193,8 +193,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
-parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
-                    help='Best metric (default: "prec1"')
+parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
+                    help='Best metric (default: "top1")')
 parser.add_argument('--tta', type=int, default=0, metavar='N',
                     help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
 parser.add_argument("--local_rank", default=0, type=int)
@@ -596,8 +596,8 @@ def train_epoch(
 def validate(model, loader, loss_fn, args, log_suffix=''):
     batch_time_m = AverageMeter()
     losses_m = AverageMeter()
-    prec1_m = AverageMeter()
-    prec5_m = AverageMeter()
+    top1_m = AverageMeter()
+    top5_m = AverageMeter()
 
     model.eval()
 
@@ -621,20 +621,20 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
                 target = target[0:target.size(0):reduce_factor]
 
             loss = loss_fn(output, target)
-            prec1, prec5 = accuracy(output, target, topk=(1, 5))
+            acc1, acc5 = accuracy(output, target, topk=(1, 5))
 
             if args.distributed:
                 reduced_loss = reduce_tensor(loss.data, args.world_size)
-                prec1 = reduce_tensor(prec1, args.world_size)
-                prec5 = reduce_tensor(prec5, args.world_size)
+                acc1 = reduce_tensor(acc1, args.world_size)
+                acc5 = reduce_tensor(acc5, args.world_size)
             else:
                 reduced_loss = loss.data
 
             torch.cuda.synchronize()
 
             losses_m.update(reduced_loss.item(), input.size(0))
-            prec1_m.update(prec1.item(), output.size(0))
-            prec5_m.update(prec5.item(), output.size(0))
+            top1_m.update(acc1.item(), output.size(0))
+            top5_m.update(acc5.item(), output.size(0))
 
             batch_time_m.update(time.time() - end)
             end = time.time()
@@ -644,13 +644,12 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
                     '{0}: [{1:>4d}/{2}] '
                     'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
-                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
-                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
-                        log_name, batch_idx, last_idx,
-                        batch_time=batch_time_m, loss=losses_m,
-                        top1=prec1_m, top5=prec5_m))
+                    'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
+                    'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
+                        log_name, batch_idx, last_idx, batch_time=batch_time_m,
+                        loss=losses_m, top1=top1_m, top5=top5_m))
 
-    metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])
+    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
 
     return metrics
 
diff --git a/validate.py b/validate.py
index 6ed32ab1..34ce95c0 100755
--- a/validate.py
+++ b/validate.py
@@ -150,10 +150,10 @@ def validate(args):
             loss = criterion(output, target)
 
             # measure accuracy and record loss
-            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
+            acc1, acc5 = accuracy(output.data, target, topk=(1, 5))
             losses.update(loss.item(), input.size(0))
-            top1.update(prec1.item(), input.size(0))
-            top5.update(prec5.item(), input.size(0))
+            top1.update(acc1.item(), input.size(0))
+            top5.update(acc5.item(), input.size(0))
 
             # measure elapsed time
             batch_time.update(time.time() - end)
@@ -164,8 +164,8 @@ def validate(args):
                     'Test: [{0:>4d}/{1}] '
                     'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                     'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
-                    'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
-                    'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
+                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
+                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                         i, len(loader), batch_time=batch_time,
                         rate_avg=input.size(0) / batch_time.avg,
                         loss=losses, top1=top1, top5=top5))
@@ -178,7 +178,7 @@ def validate(args):
         cropt_pct=crop_pct,
         interpolation=data_config['interpolation'])
 
-    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
+    logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
         results['top1'], results['top1_err'], results['top5'], results['top5_err']))
 
     return results
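
Note: the timm/utils.py hunk above cuts off before the return statement of the
renamed accuracy() helper. For reference, here is a minimal, self-contained
sketch of the helper and the renamed call-site pattern; the return line and the
demo inputs are assumptions for illustration, not part of this patch.

import torch

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    # indices of the maxk largest logits per sample: (batch, maxk) -> (maxk, batch)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # correct[i, j] is True when the (i+1)-th ranked prediction for sample j matches its label
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # assumed return (not shown in the hunk): percentage of samples whose label
    # appears in the top-k predictions, one value per requested k
    return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]

# usage mirroring the renamed call sites in train.py and validate.py
output = torch.randn(8, 1000)            # fake logits: batch of 8, 1000 classes
target = torch.randint(0, 1000, (8,))    # fake ground-truth class indices
acc1, acc5 = accuracy(output, target, topk=(1, 5))
print('Acc@1 {:.3f} Acc@5 {:.3f}'.format(acc1.item(), acc5.item()))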