From 7c7ecd24923b19338ca083d56369193e153294f0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 22:01:24 -0700 Subject: [PATCH] Add --use-train-size flag to force use of train input_size (over test input size) for validation. Default test-time pooling to use train input size (fixes issues). --- timm/models/layers/test_time_pool.py | 2 +- validate.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/timm/models/layers/test_time_pool.py b/timm/models/layers/test_time_pool.py index 98c0bf53..5826d8c9 100644 --- a/timm/models/layers/test_time_pool.py +++ b/timm/models/layers/test_time_pool.py @@ -36,7 +36,7 @@ class TestTimePoolHead(nn.Module): return x.view(x.size(0), -1) -def apply_test_time_pool(model, config, use_test_size=True): +def apply_test_time_pool(model, config, use_test_size=False): test_time_pool = False if not hasattr(model, 'default_cfg') or not model.default_cfg: return model, False diff --git a/validate.py b/validate.py index 708ac2e5..7fa22b49 100755 --- a/validate.py +++ b/validate.py @@ -67,6 +67,8 @@ parser.add_argument('--img-size', default=None, type=int, metavar='N', help='Input image dimension, uses model default if empty') parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') +parser.add_argument('--use-train-size', action='store_true', default=False, + help='force use of train input size, even when test size is specified in pretrained cfg') parser.add_argument('--crop-pct', default=None, type=float, metavar='N', help='Input image center crop pct') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', @@ -164,10 +166,15 @@ def validate(args): param_count = sum([m.numel() for m in model.parameters()]) _logger.info('Model %s created, param count: %d' % (args.model, param_count)) - data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True) + data_config = resolve_data_config( + vars(args), + model=model, + use_test_size=not args.use_train_size, + verbose=True + ) test_time_pool = False if args.test_pool: - model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True) + model, test_time_pool = apply_test_time_pool(model, data_config) if args.torchscript: torch.jit.optimized_execution(True)