From 53c47479c46df9f5bfa155e14aa20ff198dc3b15 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 15 Feb 2020 20:37:04 -0800 Subject: [PATCH] Batch validation batch size adjustment, tweak L2 crop pct --- timm/models/efficientnet.py | 2 +- validate.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 0892f8c1..993db6aa 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -194,7 +194,7 @@ default_cfgs = { input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936), 'tf_efficientnet_l2_ns': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', - input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961), + input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96), 'tf_efficientnet_es': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), diff --git a/validate.py b/validate.py index 7993faaa..6ed32ab1 100755 --- a/validate.py +++ b/validate.py @@ -211,11 +211,24 @@ def main(): logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) results = [] try: + start_batch_size = args.batch_size for m, c in model_cfgs: + batch_size = start_batch_size args.model = m args.checkpoint = c result = OrderedDict(model=args.model) - r = validate(args) + r = {} + while not r and batch_size >= args.num_gpu: + try: + args.batch_size = batch_size + print('Validating with batch size: %d' % args.batch_size) + r = validate(args) + except RuntimeError as e: + if batch_size <= args.num_gpu: + print("Validation failed with no ability to reduce batch size. Exiting.") + raise e + batch_size = max(batch_size // 2, args.num_gpu) + print("Validation failed, reducing batch size by 50%") result.update(r) if args.checkpoint: result['checkpoint'] = args.checkpoint