@ -193,8 +193,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
help = ' disable fast prefetcher ' )
help = ' disable fast prefetcher ' )
parser . add_argument ( ' --output ' , default = ' ' , type = str , metavar = ' PATH ' ,
parser . add_argument ( ' --output ' , default = ' ' , type = str , metavar = ' PATH ' ,
help = ' path to output folder (default: none, current dir) ' )
help = ' path to output folder (default: none, current dir) ' )
parser . add_argument ( ' --eval-metric ' , default = ' prec 1' , type = str , metavar = ' EVAL_METRIC ' ,
parser . add_argument ( ' --eval-metric ' , default = ' to p1' , type = str , metavar = ' EVAL_METRIC ' ,
help = ' Best metric (default: " prec 1" ' )
help = ' Best metric (default: " to p1" ' )
parser . add_argument ( ' --tta ' , type = int , default = 0 , metavar = ' N ' ,
parser . add_argument ( ' --tta ' , type = int , default = 0 , metavar = ' N ' ,
help = ' Test/inference time augmentation (oversampling) factor. 0=None (default: 0) ' )
help = ' Test/inference time augmentation (oversampling) factor. 0=None (default: 0) ' )
parser . add_argument ( " --local_rank " , default = 0 , type = int )
parser . add_argument ( " --local_rank " , default = 0 , type = int )
@ -596,8 +596,8 @@ def train_epoch(
def validate ( model , loader , loss_fn , args , log_suffix = ' ' ) :
def validate ( model , loader , loss_fn , args , log_suffix = ' ' ) :
batch_time_m = AverageMeter ( )
batch_time_m = AverageMeter ( )
losses_m = AverageMeter ( )
losses_m = AverageMeter ( )
prec 1_m = AverageMeter ( )
to p1_m = AverageMeter ( )
prec 5_m = AverageMeter ( )
to p5_m = AverageMeter ( )
model . eval ( )
model . eval ( )
@ -621,20 +621,20 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
target = target [ 0 : target . size ( 0 ) : reduce_factor ]
target = target [ 0 : target . size ( 0 ) : reduce_factor ]
loss = loss_fn ( output , target )
loss = loss_fn ( output , target )
prec1, pre c5 = accuracy ( output , target , topk = ( 1 , 5 ) )
acc1, ac c5 = accuracy ( output , target , topk = ( 1 , 5 ) )
if args . distributed :
if args . distributed :
reduced_loss = reduce_tensor ( loss . data , args . world_size )
reduced_loss = reduce_tensor ( loss . data , args . world_size )
prec1 = reduce_tensor ( pre c1, args . world_size )
acc1 = reduce_tensor ( ac c1, args . world_size )
prec5 = reduce_tensor ( pre c5, args . world_size )
acc5 = reduce_tensor ( ac c5, args . world_size )
else :
else :
reduced_loss = loss . data
reduced_loss = loss . data
torch . cuda . synchronize ( )
torch . cuda . synchronize ( )
losses_m . update ( reduced_loss . item ( ) , input . size ( 0 ) )
losses_m . update ( reduced_loss . item ( ) , input . size ( 0 ) )
prec1_m. update ( pre c1. item ( ) , output . size ( 0 ) )
top1_m. update ( ac c1. item ( ) , output . size ( 0 ) )
prec5_m. update ( pre c5. item ( ) , output . size ( 0 ) )
top5_m. update ( ac c5. item ( ) , output . size ( 0 ) )
batch_time_m . update ( time . time ( ) - end )
batch_time_m . update ( time . time ( ) - end )
end = time . time ( )
end = time . time ( )
@ -644,13 +644,12 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
' {0} : [ {1:>4d} / {2} ] '
' {0} : [ {1:>4d} / {2} ] '
' Time: {batch_time.val:.3f} ( {batch_time.avg:.3f} ) '
' Time: {batch_time.val:.3f} ( {batch_time.avg:.3f} ) '
' Loss: {loss.val:>7.4f} ( {loss.avg:>6.4f} ) '
' Loss: {loss.val:>7.4f} ( {loss.avg:>6.4f} ) '
' Prec@1: {top1.val:>7.4f} ( {top1.avg:>7.4f} ) '
' Acc@1: {top1.val:>7.4f} ( {top1.avg:>7.4f} ) '
' Prec@5: {top5.val:>7.4f} ( {top5.avg:>7.4f} ) ' . format (
' Acc@5: {top5.val:>7.4f} ( {top5.avg:>7.4f} ) ' . format (
log_name , batch_idx , last_idx ,
log_name , batch_idx , last_idx , batch_time = batch_time_m ,
batch_time = batch_time_m , loss = losses_m ,
loss = losses_m , top1 = top1_m , top5 = top5_m ) )
top1 = prec1_m , top5 = prec5_m ) )
metrics = OrderedDict ( [ ( ' loss ' , losses_m . avg ) , ( ' prec 1' , prec 1_m. avg ) , ( ' prec 5' , prec 5_m. avg ) ] )
metrics = OrderedDict ( [ ( ' loss ' , losses_m . avg ) , ( ' to p1' , to p1_m. avg ) , ( ' to p5' , to p5_m. avg ) ] )
return metrics
return metrics