Update optim_factory.py

Add `adamp` optimizer
pull/195/head
Sangdoo Yun 5 years ago committed by GitHub
parent 629b35cbcb
commit 4fae5f4c0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
import torch import torch
from torch import optim as optim from torch import optim as optim
from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead, AdamP
try: try:
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
has_apex = True has_apex = True
@ -60,6 +60,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
elif opt_lower == 'radam': elif opt_lower == 'radam':
optimizer = RAdam( optimizer = RAdam(
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
elif opt_lower == 'adamp':
optimizer = AdamP(
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps,
delta=0.1, wd_ratio=0.01, nesterov=True)
elif opt_lower == 'adadelta': elif opt_lower == 'adadelta':
optimizer = optim.Adadelta( optimizer = optim.Adadelta(
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps) parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)

Loading…
Cancel
Save