From 8ffdc5910a915e5077ab426f2f4a37f8c2b863ac Mon Sep 17 00:00:00 2001 From: Matthijs Hollemans Date: Fri, 2 Oct 2020 22:56:15 +0200 Subject: [PATCH 1/2] test_time_pool would be set to a non-False value even if test-time pooling is not available --- validate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validate.py b/validate.py index 780be5a4..5a0d388c 100755 --- a/validate.py +++ b/validate.py @@ -139,7 +139,7 @@ def validate(args): _logger.info('Model %s created, param count: %d' % (args.model, param_count)) data_config = resolve_data_config(vars(args), model=model) - model, test_time_pool = model, False if args.no_test_pool else apply_test_time_pool(model, data_config) + model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, data_config) if args.torchscript: torch.jit.optimized_execution(True) From f04bdc8c8e6b8c933cfa16d2279b7f18e16d28da Mon Sep 17 00:00:00 2001 From: Matthijs Hollemans Date: Fri, 2 Oct 2020 23:23:44 +0200 Subject: [PATCH 2/2] don't forget this file --- inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference.py b/inference.py index f402d005..16d19944 100755 --- a/inference.py +++ b/inference.py @@ -73,7 +73,7 @@ def main(): (args.model, sum([m.numel() for m in model.parameters()]))) config = resolve_data_config(vars(args), model=model) - model, test_time_pool = model, False if args.no_test_pool else apply_test_time_pool(model, config) + model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config) if args.num_gpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()