diff --git a/train.py b/train.py index 63003763..38d86992 100644 --- a/train.py +++ b/train.py @@ -428,6 +428,9 @@ def validate(model, loader, loss_fn, args): with torch.no_grad(): for batch_idx, (input, target) in enumerate(loader): last_batch = batch_idx == last_idx + if not args.prefetcher: + input = input.cuda() + target = target.cuda() output = model(input) if isinstance(output, (tuple, list)):