From e6f24e557843eb42268a43e163febb239d03ecb6 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 25 Apr 2020 19:42:13 -0700
Subject: [PATCH] Add 'momentum' optimizer (SGD w/o nesterov) for stable
 EfficientDet training defaults

---
 timm/optim/optim_factory.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py
index 553a6b6d..d97887d5 100644
--- a/timm/optim/optim_factory.py
+++ b/timm/optim/optim_factory.py
@@ -42,9 +42,12 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
 
     opt_split = opt_lower.split('_')
     opt_lower = opt_split[-1]
-    if opt_lower == 'sgd':
+    if opt_lower == 'sgd' or opt_lower == 'nesterov':
         optimizer = optim.SGD(
             parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
+    elif opt_lower == 'momentum':
+        optimizer = optim.SGD(
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=False)
     elif opt_lower == 'adam':
         optimizer = optim.Adam(
             parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
@@ -75,6 +78,9 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
     elif opt_lower == 'fusedsgd':
         optimizer = FusedSGD(
             parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
+    elif opt_lower == 'fusedmomentum':
+        optimizer = FusedSGD(
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=False)
     elif opt_lower == 'fusedadam':
         optimizer = FusedAdam(
             parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps)
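
Usage note (not part of the patch): a minimal sketch of selecting the new
'momentum' option through create_optimizer, assuming timm's
create_optimizer(args, model) signature from optim_factory. The
SimpleNamespace fields below mirror timm's training-script arguments; the
hyperparameter values and the toy model are illustrative stand-ins, not
timm code.

    # Sketch: exercising the new 'momentum' option (plain SGD, nesterov=False).
    from types import SimpleNamespace

    import torch.nn as nn
    from timm.optim import create_optimizer

    # Illustrative values; field names follow timm's train.py args.
    args = SimpleNamespace(opt='momentum', lr=0.08, momentum=0.9, weight_decay=4e-5)
    model = nn.Linear(16, 4)  # stand-in for a real timm/EfficientDet model

    optimizer = create_optimizer(args, model)  # -> torch.optim.SGD, nesterov=False

With opt='sgd' or the new opt='nesterov' alias the same call would return
SGD with nesterov=True; 'fusedmomentum' does the same for apex FusedSGD.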