@@ -1,6 +1,6 @@
 import torch
 from torch import optim as optim
-from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead
+from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead, AdamP
 try:
     from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
     has_apex = True
@@ -60,6 +60,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
     elif opt_lower == 'radam':
         optimizer = RAdam(
             parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
+    elif opt_lower == 'adamp':
+        optimizer = AdamP(
+            parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps,
+            delta=0.1, wd_ratio=0.01, nesterov=True)
     elif opt_lower == 'adadelta':
         optimizer = optim.Adadelta(
             parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
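
The new 'adamp' branch can be exercised end-to-end roughly as below. This is a minimal usage sketch, not part of the patch: it assumes the patched file is timm's optim_factory (so create_optimizer is importable from timm.optim.optim_factory, per the hunk header), that args is an argparse-style namespace exposing opt, lr, weight_decay, momentum, and opt_eps, and that 'resnet18' is just a convenient test model; delta, wd_ratio, and nesterov are hard-coded in the branch above, so no extra fields are needed for them.

from types import SimpleNamespace

import timm
from timm.optim.optim_factory import create_optimizer  # import path assumed from the hunk header

# Hypothetical hyperparameters; args.lr and args.opt_eps are the fields read in the
# AdamP branch above, the rest are the usual factory fields.
args = SimpleNamespace(opt='adamp', lr=1e-3, weight_decay=1e-4, momentum=0.9, opt_eps=1e-8)

model = timm.create_model('resnet18', pretrained=False)
optimizer = create_optimizer(args, model)  # dispatches on args.opt.lower() == 'adamp'
print(type(optimizer).__name__)  # expected: AdamP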