diff --git a/train.py b/train.py index d40ff04b..b85eb6b0 100755 --- a/train.py +++ b/train.py @@ -970,16 +970,16 @@ def validate( with amp_autocast(): output = model(input) - if isinstance(output, (tuple, list)): - output = output[0] + if isinstance(output, (tuple, list)): + output = output[0] - # augmentation reduction - reduce_factor = args.tta - if reduce_factor > 1: - output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) - target = target[0:target.size(0):reduce_factor] + # augmentation reduction + reduce_factor = args.tta + if reduce_factor > 1: + output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) + target = target[0:target.size(0):reduce_factor] - loss = loss_fn(output, target) + loss = loss_fn(output, target) acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) if args.distributed: diff --git a/validate.py b/validate.py index 6b8222b9..872f27b0 100755 --- a/validate.py +++ b/validate.py @@ -294,9 +294,9 @@ def validate(args): with amp_autocast(): output = model(input) - if valid_labels is not None: - output = output[:, valid_labels] - loss = criterion(output, target) + if valid_labels is not None: + output = output[:, valid_labels] + loss = criterion(output, target) if real_labels is not None: real_labels.add_result(output)