Merge pull request #1586 from lorenzbaraldi/eval_loss

Put validation loss under amp_autocast
pull/1578/merge
Ross Wightman 2 years ago committed by GitHub
commit f266f841a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -970,16 +970,16 @@ def validate(
with amp_autocast(): with amp_autocast():
output = model(input) output = model(input)
if isinstance(output, (tuple, list)): if isinstance(output, (tuple, list)):
output = output[0] output = output[0]
# augmentation reduction # augmentation reduction
reduce_factor = args.tta reduce_factor = args.tta
if reduce_factor > 1: if reduce_factor > 1:
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
target = target[0:target.size(0):reduce_factor] target = target[0:target.size(0):reduce_factor]
loss = loss_fn(output, target) loss = loss_fn(output, target)
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
if args.distributed: if args.distributed:

@ -294,9 +294,9 @@ def validate(args):
with amp_autocast(): with amp_autocast():
output = model(input) output = model(input)
if valid_labels is not None: if valid_labels is not None:
output = output[:, valid_labels] output = output[:, valid_labels]
loss = criterion(output, target) loss = criterion(output, target)
if real_labels is not None: if real_labels is not None:
real_labels.add_result(output) real_labels.add_result(output)

Loading…
Cancel
Save