|
|
|
@ -970,16 +970,16 @@ def validate(
|
|
|
|
|
|
|
|
|
|
with amp_autocast():
|
|
|
|
|
output = model(input)
|
|
|
|
|
if isinstance(output, (tuple, list)):
|
|
|
|
|
output = output[0]
|
|
|
|
|
if isinstance(output, (tuple, list)):
|
|
|
|
|
output = output[0]
|
|
|
|
|
|
|
|
|
|
# augmentation reduction
|
|
|
|
|
reduce_factor = args.tta
|
|
|
|
|
if reduce_factor > 1:
|
|
|
|
|
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
|
|
|
|
|
target = target[0:target.size(0):reduce_factor]
|
|
|
|
|
# augmentation reduction
|
|
|
|
|
reduce_factor = args.tta
|
|
|
|
|
if reduce_factor > 1:
|
|
|
|
|
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
|
|
|
|
|
target = target[0:target.size(0):reduce_factor]
|
|
|
|
|
|
|
|
|
|
loss = loss_fn(output, target)
|
|
|
|
|
loss = loss_fn(output, target)
|
|
|
|
|
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
|
|
|
|
|
|
|
|
|
|
if args.distributed:
|
|
|
|
|