Add model-based weight decay skip support. Improve cross-version compat of optimizer factory. Fix #247

pull/250/head
Ross Wightman 4 years ago
parent 80078c47bb
commit a4d8fea61e

@@ -41,7 +41,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
     opt_lower = args.opt.lower()
     weight_decay = args.weight_decay
     if weight_decay and filter_bias_and_bn:
-        parameters = add_weight_decay(model, weight_decay)
+        skip = {}
+        if hasattr(model, 'no_weight_decay'):
+            skip = model.no_weight_decay()
+        parameters = add_weight_decay(model, weight_decay, skip)
         weight_decay = 0.
     else:
         parameters = model.parameters()
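
The no_weight_decay hook is how a model (e.g. a ViT with pos_embed / cls_token parameters) can exclude specific parameters from weight decay without touching the training script. Below is a minimal sketch of that pattern, under the assumption that add_weight_decay splits parameters into decay / no-decay groups; the names ToyViT and split_decay_groups are hypothetical illustrations, not timm's API.

# Illustrative sketch only -- ToyViT and split_decay_groups are made-up names,
# not timm code. A model exposes no_weight_decay() returning parameter names,
# and the optimizer factory routes those parameters into a zero-decay group.
import torch
import torch.nn as nn


class ToyViT(nn.Module):
    def __init__(self):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, 16, 64))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, 64))
        self.head = nn.Linear(64, 10)

    def no_weight_decay(self):
        # names of parameters that should never receive weight decay
        return {'pos_embed', 'cls_token'}


def split_decay_groups(model, weight_decay, skip=()):
    # hypothetical stand-in for add_weight_decay: 1-d params, biases, and
    # skipped names go into the zero-decay group, everything else gets decay
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndim <= 1 or name.endswith('.bias') or name in skip:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay},
    ]


model = ToyViT()
skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else {}
optimizer = torch.optim.AdamW(split_decay_groups(model, 0.05, skip), lr=1e-3)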
@@ -50,9 +53,9 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
         assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
 
     opt_args = dict(lr=args.lr, weight_decay=weight_decay)
-    if args.opt_eps is not None:
+    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
         opt_args['eps'] = args.opt_eps
-    if args.opt_betas is not None:
+    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
         opt_args['betas'] = args.opt_betas
 
     opt_split = opt_lower.split('_')
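
The hasattr() guards are the cross-version compatibility part: an args namespace coming from an older training script, or one built by hand, may not define opt_eps or opt_betas at all, and the unguarded lookups raised AttributeError. A small illustration of the guarded pattern; the SimpleNamespace here is a stand-in for an argparse Namespace, not the actual timm call path.

# Sketch of the guarded args handling above; SimpleNamespace is a stand-in
# for an argparse Namespace that predates the opt_eps / opt_betas options.
from types import SimpleNamespace

args = SimpleNamespace(opt='adamw', lr=1e-3, weight_decay=0.05)

opt_args = dict(lr=args.lr, weight_decay=args.weight_decay)
if hasattr(args, 'opt_eps') and args.opt_eps is not None:
    opt_args['eps'] = args.opt_eps
if hasattr(args, 'opt_betas') and args.opt_betas is not None:
    opt_args['betas'] = args.opt_betas

print(opt_args)  # {'lr': 0.001, 'weight_decay': 0.05} -- optimizer defaults apply for eps/betas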
