diff --git a/validate.py b/validate.py index 9b2c0f7e..2e18841f 100755 --- a/validate.py +++ b/validate.py @@ -80,8 +80,8 @@ parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') parser.add_argument('--num-gpu', type=int, default=1, help='Number of GPUS to use') -parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', - help='disable test time pool') +parser.add_argument('--test-pool', dest='test_pool', action='store_true', + help='enable test time pool') parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--pin-mem', action='store_true', default=False, @@ -154,7 +154,7 @@ def validate(args): data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True) test_time_pool = False - if not args.no_test_pool: + if args.test_pool: model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True) if args.torchscript: