don't forget this file

pull/244/head
Matthijs Hollemans 4 years ago
parent 8ffdc5910a
commit f04bdc8c8e

@ -73,7 +73,7 @@ def main():
(args.model, sum([m.numel() for m in model.parameters()])))
config = resolve_data_config(vars(args), model=model)
model, test_time_pool = model, False if args.no_test_pool else apply_test_time_pool(model, config)
model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config)
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()

Loading…
Cancel
Save