From 7b83e67f77122d5c07ccafaa4c09719f947e00b2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 22 Nov 2019 13:27:43 -0800 Subject: [PATCH] Pass drop connect arg through to EfficientNet models --- timm/models/factory.py | 7 ++++--- train.py | 5 ++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/timm/models/factory.py b/timm/models/factory.py index d807a342..3c051e75 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -25,12 +25,13 @@ def create_model( """ margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans) - # Not all models have support for batchnorm params passed as args, only gen_efficientnet variants - supports_bn_params = is_model_in_modules(model_name, ['gen_efficientnet']) - if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]): + # Only gen_efficientnet models have support for batchnorm params or drop_connect_rate passed as args + is_efficientnet = is_model_in_modules(model_name, ['gen_efficientnet']) + if not is_efficientnet: kwargs.pop('bn_tf', None) kwargs.pop('bn_momentum', None) kwargs.pop('bn_eps', None) + kwargs.pop('drop_connect_rate', None) if is_model(model_name): create_fn = model_entrypoint(model_name) diff --git a/train.py b/train.py index b0e18bdd..776e5ef2 100644 --- a/train.py +++ b/train.py @@ -65,6 +65,8 @@ parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N', help='input batch size for training (default: 32)') parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', help='Dropout rate (default: 0.)') +parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP', + help='Drop connect rate (default: 0.)') # Optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') @@ -208,6 +210,7 @@ def main(): pretrained=args.pretrained, num_classes=args.num_classes, drop_rate=args.drop, + drop_connect_rate=args.drop_connect, global_pool=args.gp, bn_tf=args.bn_tf, bn_momentum=args.bn_momentum, @@ -253,7 +256,7 @@ def main(): if args.local_rank == 0: logging.info('Restoring NVIDIA AMP state from checkpoint') amp.load_state_dict(resume_state['amp']) - resume_state = None + resume_state = None # clear it model_ema = None if args.model_ema: