diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index c53be368..80bac373 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -41,7 +41,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True): opt_lower = args.opt.lower() weight_decay = args.weight_decay if weight_decay and filter_bias_and_bn: - parameters = add_weight_decay(model, weight_decay) + skip = {} + if hasattr(model, 'no_weight_decay'): + skip = model.no_weight_decay() + parameters = add_weight_decay(model, weight_decay, skip) weight_decay = 0. else: parameters = model.parameters() @@ -50,9 +53,9 @@ def create_optimizer(args, model, filter_bias_and_bn=True): assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' opt_args = dict(lr=args.lr, weight_decay=weight_decay) - if args.opt_eps is not None: + if hasattr(args, 'opt_eps') and args.opt_eps is not None: opt_args['eps'] = args.opt_eps - if args.opt_betas is not None: + if hasattr(args, 'opt_betas') and args.opt_betas is not None: opt_args['betas'] = args.opt_betas opt_split = opt_lower.split('_')