@@ -1,13 +1,11 @@
 import math
 import torch
 from torch.optim.optimizer import Optimizer
-from tabulate import tabulate
-from colorama import Fore, Back, Style

 version_higher = ( torch.__version__ >= "1.5.0" )

 class AdaBelief(Optimizer):
     r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch

     Arguments:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups
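The `version_higher` check kept above gates the `memory_format` keyword of `torch.zeros_like`, which PyTorch only accepts from version 1.5 onward; the hunks below retire that guard in favor of plain `torch.zeros_like(p.data)`. A minimal sketch of the gated pattern, assuming only a stand-in tensor `p` (not from this file):

```python
import torch

p = torch.randn(4, 4)  # stand-in for a parameter tensor

# Pre-diff allocation pattern: pass memory_format only when the running
# PyTorch supports it (the keyword was added in torch 1.5).
if torch.__version__ >= "1.5.0":
    buf = torch.zeros_like(p, memory_format=torch.preserve_format)
else:
    buf = torch.zeros_like(p)
```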
@@ -33,39 +31,17 @@ class AdaBelief(Optimizer):
             update similar to RAdam
         degenerated_to_sgd (boolean, optional) (default: True) If set as True, then perform SGD update
             when variance of gradient is high
-        print_change_log (boolean, optional) (default: True) If set as True, print the modification to
-            default hyper-parameters

     reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020

     For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer
     For example train/args for EfficientNet, see these gists:
       - link to train_script: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037
       - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
     """

     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                  weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True,
-                 degenerated_to_sgd=True, print_change_log=True):
-
-        # ------------------------------------------------------------------------------
-        # Print modifications to default arguments
-        if print_change_log:
-            print(Fore.RED + 'Please check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.')
-            print(Fore.RED + 'Modifications to default arguments:')
-            default_table = tabulate([
-                ['adabelief-pytorch=0.0.5', '1e-8', 'False', 'False'],
-                ['>=0.1.0 (Current 0.2.0)', '1e-16', 'True', 'True']],
-                headers=['eps', 'weight_decouple', 'rectify'])
-            print(Fore.RED + default_table)
-
-            recommend_table = tabulate([
-                ['Recommended eps = 1e-8', 'Recommended eps = 1e-16'],
-                ],
-                headers=['SGD better than Adam (e.g. CNN for Image Classification)', 'Adam better than SGD (e.g. Transformer, GAN)'])
-            print(Fore.BLUE + recommend_table)
-
-            print(Fore.BLUE + 'For a complete table of recommended hyperparameters, see')
-            print(Fore.BLUE + 'https://github.com/juntang-zhuang/Adabelief-Optimizer')
-
-            print(Fore.GREEN + 'You can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.')
-
-            print(Style.RESET_ALL)
-        # ------------------------------------------------------------------------------
+                 degenerated_to_sgd=True):

         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -90,14 +66,6 @@ class AdaBelief(Optimizer):
         self.weight_decouple = weight_decouple
         self.rectify = rectify
         self.fixed_decay = fixed_decay
-        if self.weight_decouple:
-            print('Weight decoupling enabled in AdaBelief')
-            if self.fixed_decay:
-                print('Weight decay fixed')
-        if self.rectify:
-            print('Rectification enabled in AdaBelief')
-        if amsgrad:
-            print('AMSGrad enabled in AdaBelief')

     def __setstate__(self, state):
         super(AdaBelief, self).__setstate__(state)
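`__setstate__` is invoked when a pickled optimizer is restored; delegating to the base `Optimizer` keeps that path consistent with `torch.optim.Adam` (upstream follows the `super()` call by defaulting `amsgrad` in each param group, which falls outside this hunk). A hedged round-trip sketch, with `model` again a placeholder:

```python
import pickle

from torch import nn

model = nn.Linear(10, 2)                   # placeholder model
optimizer = AdaBelief(model.parameters())  # assumes the class is in scope

# Pickling the optimizer object round-trips through __setstate__ on load.
restored = pickle.loads(pickle.dumps(optimizer))

# Checkpoints normally go through state_dict()/load_state_dict() instead.
optimizer.load_state_dict(restored.state_dict())
```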
@@ -113,17 +81,13 @@ class AdaBelief(Optimizer):
                 # State initialization
                 state['step'] = 0
                 # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                    if version_higher else torch.zeros_like(p.data)
+                state['exp_avg'] = torch.zeros_like(p.data)

                 # Exponential moving average of squared gradient values
-                state['exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                    if version_higher else torch.zeros_like(p.data)
+                state['exp_avg_var'] = torch.zeros_like(p.data)
                 if amsgrad:
                     # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                        if version_higher else torch.zeros_like(p.data)
+                    state['max_exp_avg_var'] = torch.zeros_like(p.data)

     def step(self, closure=None):
         """Performs a single optimization step.
@@ -161,15 +125,12 @@ class AdaBelief(Optimizer):
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                        if version_higher else torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p.data)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                        if version_higher else torch.zeros_like(p.data)
+                    state['exp_avg_var'] = torch.zeros_like(p.data)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                            if version_higher else torch.zeros_like(p.data)
+                        state['max_exp_avg_var'] = torch.zeros_like(p.data)

                 # perform weight decay, check if decoupled weight decay
                 if self.weight_decouple:
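The excerpt ends at the decoupled-decay branch. For orientation, a hedged sketch of the two behaviors this flag selects, following the AdamW-style pattern used upstream; the `group`, `p`, and `grad` names match the surrounding code, and the exact body lies in the omitted lines:

```python
# Decoupled (AdamW-style): shrink the weights directly, keeping the
# gradient-based update and the regularization independent.
if self.weight_decouple:
    if not self.fixed_decay:
        p.data.mul_(1.0 - group['lr'] * group['weight_decay'])
    else:
        p.data.mul_(1.0 - group['weight_decay'])
# Coupled (classic Adam): fold the decay term into the gradient.
elif group['weight_decay'] != 0:
    grad.add_(p.data, alpha=group['weight_decay'])
```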