Update validation script: prime the first batch and clear the CUDA cache between multi-model runs

pull/155/head
Ross Wightman 4 years ago
parent 0aca08384f
commit d3ee3de96a

@@ -145,7 +145,8 @@ def validate(args):
model.eval()
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
model(torch.randn((args.batch_size,) + data_config['input_size']).cuda())
input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
model(input)
end = time.time()
for i, (input, target) in enumerate(loader):
if args.no_prefetcher:
@@ -238,6 +239,7 @@ def main():
raise e
batch_size = max(batch_size // 2, args.num_gpu)
print("Validation failed, reducing batch size by 50%")
torch.cuda.empty_cache()
result.update(r)
if args.checkpoint:
result['checkpoint'] = args.checkpoint

Loading…
Cancel
Save