From 3d6bc42aa15b67883e3bf0f92df92fc7b74030b1 Mon Sep 17 00:00:00 2001 From: Lorenzo Baraldi Date: Fri, 9 Dec 2022 12:03:23 +0100 Subject: [PATCH] Put validation loss under amp_autocast Secured the loss evaluation under the amp, avoiding function to operate on float16 --- train.py | 16 ++++++++-------- validate.py | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/train.py b/train.py index d40ff04b..b85eb6b0 100755 --- a/train.py +++ b/train.py @@ -970,16 +970,16 @@ def validate( with amp_autocast(): output = model(input) - if isinstance(output, (tuple, list)): - output = output[0] + if isinstance(output, (tuple, list)): + output = output[0] - # augmentation reduction - reduce_factor = args.tta - if reduce_factor > 1: - output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) - target = target[0:target.size(0):reduce_factor] + # augmentation reduction + reduce_factor = args.tta + if reduce_factor > 1: + output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) + target = target[0:target.size(0):reduce_factor] - loss = loss_fn(output, target) + loss = loss_fn(output, target) acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) if args.distributed: diff --git a/validate.py b/validate.py index 6b8222b9..872f27b0 100755 --- a/validate.py +++ b/validate.py @@ -294,9 +294,9 @@ def validate(args): with amp_autocast(): output = model(input) - if valid_labels is not None: - output = output[:, valid_labels] - loss = criterion(output, target) + if valid_labels is not None: + output = output[:, valid_labels] + loss = criterion(output, target) if real_labels is not None: real_labels.add_result(output)