|
|
@ -42,9 +42,12 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
|
|
|
|
|
|
|
|
|
|
|
|
opt_split = opt_lower.split('_')
|
|
|
|
opt_split = opt_lower.split('_')
|
|
|
|
opt_lower = opt_split[-1]
|
|
|
|
opt_lower = opt_split[-1]
|
|
|
|
if opt_lower == 'sgd':
|
|
|
|
if opt_lower == 'sgd' or opt_lower == 'nesterov':
|
|
|
|
optimizer = optim.SGD(
|
|
|
|
optimizer = optim.SGD(
|
|
|
|
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
|
|
|
|
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
|
|
|
|
|
|
|
|
elif opt_lower == 'momentum':
|
|
|
|
|
|
|
|
optimizer = optim.SGD(
|
|
|
|
|
|
|
|
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=False)
|
|
|
|
elif opt_lower == 'adam':
|
|
|
|
elif opt_lower == 'adam':
|
|
|
|
optimizer = optim.Adam(
|
|
|
|
optimizer = optim.Adam(
|
|
|
|
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
|
|
|
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
|
|
|
@ -75,6 +78,9 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
|
|
|
|
elif opt_lower == 'fusedsgd':
|
|
|
|
elif opt_lower == 'fusedsgd':
|
|
|
|
optimizer = FusedSGD(
|
|
|
|
optimizer = FusedSGD(
|
|
|
|
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
|
|
|
|
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=True)
|
|
|
|
|
|
|
|
elif opt_lower == 'fusedmomentum':
|
|
|
|
|
|
|
|
optimizer = FusedSGD(
|
|
|
|
|
|
|
|
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, nesterov=False)
|
|
|
|
elif opt_lower == 'fusedadam':
|
|
|
|
elif opt_lower == 'fusedadam':
|
|
|
|
optimizer = FusedAdam(
|
|
|
|
optimizer = FusedAdam(
|
|
|
|
parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps)
|
|
|
|
parameters, lr=args.lr, adam_w_mode=False, weight_decay=weight_decay, eps=args.opt_eps)
|
|
|
|