|
|
|
@ -428,6 +428,9 @@ def validate(model, loader, loss_fn, args):
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
for batch_idx, (input, target) in enumerate(loader):
|
|
|
|
|
last_batch = batch_idx == last_idx
|
|
|
|
|
if not args.prefetcher:
|
|
|
|
|
input = input.cuda()
|
|
|
|
|
target = target.cuda()
|
|
|
|
|
|
|
|
|
|
output = model(input)
|
|
|
|
|
if isinstance(output, (tuple, list)):
|
|
|
|
|