diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py
index f3a6deb0..89d01b03 100644
--- a/timm/optim/optim_factory.py
+++ b/timm/optim/optim_factory.py
@@ -1,6 +1,6 @@
 import torch
 from torch import optim as optim
-from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead, AdamP
+from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead, AdamP, SGDP
 try:
     from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
     has_apex = True
@@ -63,7 +63,11 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
     elif opt_lower == 'adamp':
         optimizer = AdamP(
             parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps,
-            delta=0.1, wd_ratio=0.01, nesterov=True)
+            delta=0.1, wd_ratio=0.01, nesterov=True)
+    elif opt_lower == 'sgdp':
+        optimizer = SGDP(
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay,
+            eps=args.opt_eps, nesterov=True)
     elif opt_lower == 'adadelta':
         optimizer = optim.Adadelta(
             parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
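
For reference, a minimal sketch (not part of the diff) of exercising the new 'sgdp' branch through create_optimizer. The args namespace here is hand-built for illustration; the attribute names (opt, lr, momentum, weight_decay, opt_eps) mirror the flags used by timm's training script, and the exact set of attributes create_optimizer reads may differ slightly across timm versions. The tiny nn.Linear model is only a placeholder.

# Usage sketch, assuming the factory only needs the args attributes listed below.
from types import SimpleNamespace

import torch.nn as nn
from timm.optim.optim_factory import create_optimizer

args = SimpleNamespace(
    opt='sgdp',          # selects the new SGDP branch added in this diff
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    opt_eps=1e-8,
)
model = nn.Linear(10, 2)
optimizer = create_optimizer(args, model)
print(type(optimizer).__name__)  # expected: SGDP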