Update optim_factory.py

pull/195/head
Sangdoo Yun committed 5 years ago (via GitHub)
parent 7271e81cbb
commit 64857fee65

@@ -1,6 +1,6 @@
 import torch
 from torch import optim as optim
-from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead, AdamP
+from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead, AdamP, SGDP
 try:
     from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
     has_apex = True
@@ -64,6 +64,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
         optimizer = AdamP(
             parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps,
             delta=0.1, wd_ratio=0.01, nesterov=True)
+    elif opt_lower == 'sgdp':
+        optimizer = SGDP(
+            parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay,
+            eps=args.opt_eps, nesterov=True)
     elif opt_lower == 'adadelta':
         optimizer = optim.Adadelta(
             parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
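
For reference, a minimal sketch of how the new branch would be exercised. The SimpleNamespace stand-in for the training args is an assumption, not part of this commit; its fields simply mirror the ones create_optimizer reads in the hunks above.

# Hypothetical usage sketch (not part of this commit): select the new
# SGDP optimizer through the factory by passing opt='sgdp'.
from types import SimpleNamespace

import torch.nn as nn
from timm.optim.optim_factory import create_optimizer

model = nn.Linear(10, 2)  # any nn.Module works here
args = SimpleNamespace(
    opt='sgdp',        # lower-cased and matched by the new elif branch
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    opt_eps=1e-8,
)
optimizer = create_optimizer(args, model)
print(type(optimizer).__name__)  # expected: SGDP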
