Pass drop connect arg through to EfficientNet models

Branch: pull/52/head
Author: Ross Wightman, 5 years ago
Parent: 31453b039e
Commit: 7b83e67f77

@@ -25,12 +25,13 @@ def create_model(
     """ """
     margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
-    # Not all models have support for batchnorm params passed as args, only gen_efficientnet variants
-    supports_bn_params = is_model_in_modules(model_name, ['gen_efficientnet'])
-    if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
+    # Only gen_efficientnet models have support for batchnorm params or drop_connect_rate passed as args
+    is_efficientnet = is_model_in_modules(model_name, ['gen_efficientnet'])
+    if not is_efficientnet:
         kwargs.pop('bn_tf', None)
         kwargs.pop('bn_momentum', None)
         kwargs.pop('bn_eps', None)
+        kwargs.pop('drop_connect_rate', None)
 
     if is_model(model_name):
         create_fn = model_entrypoint(model_name)
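
Note on the filtering above: kwargs that only gen_efficientnet entrypoints accept are popped before dispatch, so other model families never see arguments their constructors would reject. A minimal sketch of the resulting behavior (import path and model names are assumptions for illustration, not taken from this diff):

    from timm import create_model

    # gen_efficientnet entrypoints accept the extra kwargs directly
    effnet = create_model('efficientnet_b0', drop_connect_rate=0.2)

    # for any other family, drop_connect_rate is popped and silently ignored
    # instead of raising a TypeError in the model constructor
    resnet = create_model('resnet50', drop_connect_rate=0.2)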

@@ -65,6 +65,8 @@ parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                     help='input batch size for training (default: 32)')
 parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                     help='Dropout rate (default: 0.)')
+parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
+                    help='Drop connect rate (default: 0.)')
 # Optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                     help='Optimizer (default: "sgd"')
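
With the new flag in place, the drop connect rate can be set from the command line; for example (dataset path and model name are placeholders, assuming the usual train.py invocation):

    python train.py /data/imagenet --model efficientnet_b0 --drop 0.2 --drop-connect 0.2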
@@ -208,6 +210,7 @@ def main():
         pretrained=args.pretrained,
         num_classes=args.num_classes,
         drop_rate=args.drop,
+        drop_connect_rate=args.drop_connect,
         global_pool=args.gp,
         bn_tf=args.bn_tf,
         bn_momentum=args.bn_momentum,
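
For context, drop_connect_rate controls stochastic-depth-style regularization inside the EfficientNet blocks: during training, each sample's residual branch is zeroed with the given probability and survivors are rescaled so the expected output is unchanged. A minimal sketch of the technique, assuming NCHW tensors (not this repo's exact implementation):

    import torch

    def drop_connect(x, drop_rate: float, training: bool):
        # At inference time, or with rate 0, this is an identity op
        if not training or drop_rate == 0.:
            return x
        keep_prob = 1.0 - drop_rate
        # One Bernoulli draw per sample; broadcast over C, H, W
        mask = torch.empty(x.shape[0], 1, 1, 1, dtype=x.dtype,
                           device=x.device).bernoulli_(keep_prob)
        # Rescale surviving branches to preserve the expected value
        return x / keep_prob * mask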
@@ -253,7 +256,7 @@ def main():
         if args.local_rank == 0:
             logging.info('Restoring NVIDIA AMP state from checkpoint')
         amp.load_state_dict(resume_state['amp'])
-    resume_state = None
+    resume_state = None  # clear it
 
     model_ema = None
     if args.model_ema:
