From 30ab4a1494a8c9f440be940297b31497e0cbf411 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 31 Oct 2020 18:03:35 -0700 Subject: [PATCH] Fix issue in optim factory with sgd / eps flag. Bump version to 0.3.1 --- timm/optim/optim_factory.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py index c4a43a2e..ecc61c5f 100644 --- a/timm/optim/optim_factory.py +++ b/timm/optim/optim_factory.py @@ -61,10 +61,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True): opt_split = opt_lower.split('_') opt_lower = opt_split[-1] if opt_lower == 'sgd' or opt_lower == 'nesterov': - del opt_args['eps'] + opt_args.pop('eps', None) optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) elif opt_lower == 'momentum': - del opt_args['eps'] + opt_args.pop('eps', None) optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) elif opt_lower == 'adam': optimizer = optim.Adam(parameters, **opt_args) @@ -95,10 +95,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True): elif opt_lower == 'nvnovograd': optimizer = NvNovoGrad(parameters, **opt_args) elif opt_lower == 'fusedsgd': - del opt_args['eps'] + opt_args.pop('eps', None) optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) elif opt_lower == 'fusedmomentum': - del opt_args['eps'] + opt_args.pop('eps', None) optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) elif opt_lower == 'fusedadam': optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)