|
|
|
@ -752,15 +752,16 @@ def train_one_epoch(
|
|
|
|
|
|
|
|
|
|
return OrderedDict([('loss', losses_m.avg)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
|
|
|
|
|
|
|
|
|
|
batch_time_m = AverageMeter()
|
|
|
|
|
losses_m = AverageMeter()
|
|
|
|
|
top1_m = AverageMeter()
|
|
|
|
|
top5_m = AverageMeter()
|
|
|
|
|
|
|
|
|
|
top5_ = AverageMeter()
|
|
|
|
|
top_p = AverageMeter()
|
|
|
|
|
top_r = AverageMeter()
|
|
|
|
|
top_f = AverageMeter()
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
|
end = time.time()
|
|
|
|
|
last_idx = len(loader) - 1
|
|
|
|
|
with torch.no_grad():
|
|
|
|
@ -771,35 +772,35 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
|
|
|
|
|
target = target.cuda()
|
|
|
|
|
if args.channels_last:
|
|
|
|
|
input = input.contiguous(memory_format=torch.channels_last)
|
|
|
|
|
|
|
|
|
|
with amp_autocast():
|
|
|
|
|
output = model(input)
|
|
|
|
|
if isinstance(output, (tuple, list)):
|
|
|
|
|
output = output[0]
|
|
|
|
|
|
|
|
|
|
# augmentation reduction
|
|
|
|
|
reduce_factor = args.tta
|
|
|
|
|
if reduce_factor > 1:
|
|
|
|
|
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
|
|
|
|
|
target = target[0:target.size(0):reduce_factor]
|
|
|
|
|
|
|
|
|
|
target = target[0:target.size(0):reduce_factor]
|
|
|
|
|
loss = loss_fn(output, target)
|
|
|
|
|
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
|
|
|
|
|
|
|
|
|
# acc1 = accuracy(output, target, topk=(1,1))
|
|
|
|
|
acc1, acc5 = accuracy(output, target, topk=(1, k))
|
|
|
|
|
f1 = f1_scor(output, target)
|
|
|
|
|
prec = precision(output.detach(), target)
|
|
|
|
|
rec = recall(output.detach(), target)
|
|
|
|
|
if args.distributed:
|
|
|
|
|
reduced_loss = reduce_tensor(loss.data, args.world_size)
|
|
|
|
|
acc1 = reduce_tensor(acc1, args.world_size)
|
|
|
|
|
acc5 = reduce_tensor(acc5, args.world_size)
|
|
|
|
|
# acc5 = reduce_tensor(acc5, args.world_size)
|
|
|
|
|
else:
|
|
|
|
|
reduced_loss = loss.data
|
|
|
|
|
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
|
|
|
|
|
|
losses_m.update(reduced_loss.item(), input.size(0))
|
|
|
|
|
top1_m.update(acc1.item(), output.size(0))
|
|
|
|
|
top5_m.update(acc5.item(), output.size(0))
|
|
|
|
|
|
|
|
|
|
batch_time_m.update(time.time() - end)
|
|
|
|
|
losses_m.update(reduced_loss.item(),'acc', input.size(0))
|
|
|
|
|
top1_m.update(acc1.item(),'acc', output.size(0))
|
|
|
|
|
top_p.update(prec,'prec', input.size(0))
|
|
|
|
|
top_r.update(rec,'rec', input.size(0))
|
|
|
|
|
top_f.update(f1,'f1', input.size(0))
|
|
|
|
|
# top5_m.update(acc5.item(), output.size(0))
|
|
|
|
|
batch_time_m.update(time.time() - end,'acc')
|
|
|
|
|
end = time.time()
|
|
|
|
|
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
|
|
|
|
|
log_name = 'Test' + log_suffix
|
|
|
|
@ -808,13 +809,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
|
|
|
|
|
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
|
|
|
|
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
|
|
|
|
'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
|
|
|
|
|
'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
|
|
|
|
|
'Precision: {top_p.val:>7.3f} ({top_p.avg_pre:>7.3f}) '
|
|
|
|
|
'Recall: {top_r.val:>7.3f} ({top_r.avg_rec:>7.3f})'
|
|
|
|
|
'F1: {top_f.val:>7.3f} ({top_f.avg_f1:>7.3f})'.format(
|
|
|
|
|
log_name, batch_idx, last_idx, batch_time=batch_time_m,
|
|
|
|
|
loss=losses_m, top1=top1_m, top5=top5_m))
|
|
|
|
|
|
|
|
|
|
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
|
|
|
|
|
|
|
|
|
|
return metrics
|
|
|
|
|
loss=losses_m, top1=top1_m,top_p=top_p,top_r=top_r,top_f=top_f))
|
|
|
|
|
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('f1', top_f.avg_f1), ('Precision',top_p.avg_pre), ('Recall',top_r.avg_rec)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|