diff --git a/inference.py b/inference.py
index f402d005..16d19944 100755
--- a/inference.py
+++ b/inference.py
@@ -73,7 +73,7 @@ def main():
         (args.model, sum([m.numel() for m in model.parameters()])))
 
     config = resolve_data_config(vars(args), model=model)
-    model, test_time_pool = model, False if args.no_test_pool else apply_test_time_pool(model, config)
+    model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config)
 
     if args.num_gpu > 1:
         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
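
Without the parentheses, the conditional expression binds only to the second element of the right-hand tuple, so when --no-test-pool is not set, test_time_pool ends up holding the whole (model, flag) tuple returned by apply_test_time_pool and the pooled model wrapper is discarded. A minimal sketch of the two parses; the stub below merely stands in for timm's apply_test_time_pool, which returns the possibly wrapped model plus a boolean flag:

    # Illustrative stub only: returns the (possibly wrapped) model and a
    # flag saying whether test-time pooling was applied.
    def apply_test_time_pool_stub(model, config):
        return 'pooled-' + model, True

    model, config, no_test_pool = 'resnet50', {}, False

    # Old parse: the conditional applies only to the second tuple element,
    # i.e.  model, tp = model, (False if no_test_pool else stub(...))
    model_old, tp_old = model, False if no_test_pool else apply_test_time_pool_stub(model, config)
    print(model_old, tp_old)   # resnet50 ('pooled-resnet50', True)  -> wrapper lost, flag is a tuple

    # New parse: parentheses make the conditional choose between two tuples.
    model_new, tp_new = (model, False) if no_test_pool else apply_test_time_pool_stub(model, config)
    print(model_new, tp_new)   # pooled-resnet50 True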