Update validation script first batch prime and clear cuda cache between multi-model runs

pull/155/head
Ross Wightman 5 years ago
parent 0aca08384f
commit d3ee3de96a

@ -145,7 +145,8 @@ def validate(args):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
model(torch.randn((args.batch_size,) + data_config['input_size']).cuda()) input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
model(input)
end = time.time() end = time.time()
for i, (input, target) in enumerate(loader): for i, (input, target) in enumerate(loader):
if args.no_prefetcher: if args.no_prefetcher:
@ -238,6 +239,7 @@ def main():
raise e raise e
batch_size = max(batch_size // 2, args.num_gpu) batch_size = max(batch_size // 2, args.num_gpu)
print("Validation failed, reducing batch size by 50%") print("Validation failed, reducing batch size by 50%")
torch.cuda.empty_cache()
result.update(r) result.update(r)
if args.checkpoint: if args.checkpoint:
result['checkpoint'] = args.checkpoint result['checkpoint'] = args.checkpoint

Loading…
Cancel
Save