|
|
|
@ -145,7 +145,8 @@ def validate(args):
|
|
|
|
|
model.eval()
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
|
|
|
|
|
model(torch.randn((args.batch_size,) + data_config['input_size']).cuda())
|
|
|
|
|
input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
|
|
|
|
|
model(input)
|
|
|
|
|
end = time.time()
|
|
|
|
|
for i, (input, target) in enumerate(loader):
|
|
|
|
|
if args.no_prefetcher:
|
|
|
|
@ -238,6 +239,7 @@ def main():
|
|
|
|
|
raise e
|
|
|
|
|
batch_size = max(batch_size // 2, args.num_gpu)
|
|
|
|
|
print("Validation failed, reducing batch size by 50%")
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
result.update(r)
|
|
|
|
|
if args.checkpoint:
|
|
|
|
|
result['checkpoint'] = args.checkpoint
|
|
|
|
|