Fix issue in optim factory with sgd / eps flag. Bump version to 0.3.1

pull/268/head
Ross Wightman 4 years ago
parent 46f15443be
commit 30ab4a1494

@ -61,10 +61,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
opt_split = opt_lower.split('_')
opt_lower = opt_split[-1]
if opt_lower == 'sgd' or opt_lower == 'nesterov':
del opt_args['eps']
opt_args.pop('eps', None)
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
elif opt_lower == 'momentum':
del opt_args['eps']
opt_args.pop('eps', None)
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
elif opt_lower == 'adam':
optimizer = optim.Adam(parameters, **opt_args)
@ -95,10 +95,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
elif opt_lower == 'nvnovograd':
optimizer = NvNovoGrad(parameters, **opt_args)
elif opt_lower == 'fusedsgd':
del opt_args['eps']
opt_args.pop('eps', None)
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
elif opt_lower == 'fusedmomentum':
del opt_args['eps']
opt_args.pop('eps', None)
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
elif opt_lower == 'fusedadam':
optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)

Loading…
Cancel
Save