Fix adabelief imports, remove prints, preserve memory format is the default arg for zeros_like

pull/556/head
Ross Wightman 4 years ago
parent 21812d33aa
commit cd3dc4979f

@ -1,13 +1,11 @@
import math import math
import torch import torch
from torch.optim.optimizer import Optimizer from torch.optim.optimizer import Optimizer
from tabulate import tabulate
from colorama import Fore, Back, Style
version_higher = ( torch.__version__ >= "1.5.0" )
class AdaBelief(Optimizer): class AdaBelief(Optimizer):
r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch
Arguments: Arguments:
params (iterable): iterable of parameters to optimize or dicts defining params (iterable): iterable of parameters to optimize or dicts defining
parameter groups parameter groups
@ -33,39 +31,17 @@ class AdaBelief(Optimizer):
update similar to RAdam update similar to RAdam
degenerated_to_sgd (boolean, optional) (default:True) If set as True, then perform SGD update degenerated_to_sgd (boolean, optional) (default:True) If set as True, then perform SGD update
when variance of gradient is high when variance of gradient is high
print_change_log (boolean, optional) (default: True) If set as True, print the modifcation to
default hyper-parameters
reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020 reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020
For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer'
For example train/args for EfficientNet see these gists
- link to train_scipt: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037
- link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
""" """
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True, weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True,
degenerated_to_sgd=True, print_change_log = True): degenerated_to_sgd=True):
# ------------------------------------------------------------------------------
# Print modifications to default arguments
if print_change_log:
print(Fore.RED + 'Please check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.')
print(Fore.RED + 'Modifications to default arguments:')
default_table = tabulate([
['adabelief-pytorch=0.0.5','1e-8','False','False'],
['>=0.1.0 (Current 0.2.0)','1e-16','True','True']],
headers=['eps','weight_decouple','rectify'])
print(Fore.RED + default_table)
recommend_table = tabulate([
['Recommended eps = 1e-8', 'Recommended eps = 1e-16'],
],
headers=['SGD better than Adam (e.g. CNN for Image Classification)','Adam better than SGD (e.g. Transformer, GAN)'])
print(Fore.BLUE + recommend_table)
print(Fore.BLUE +'For a complete table of recommended hyperparameters, see')
print(Fore.BLUE + 'https://github.com/juntang-zhuang/Adabelief-Optimizer')
print(Fore.GREEN + 'You can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.')
print(Style.RESET_ALL)
# ------------------------------------------------------------------------------
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError("Invalid learning rate: {}".format(lr))
@ -90,14 +66,6 @@ class AdaBelief(Optimizer):
self.weight_decouple = weight_decouple self.weight_decouple = weight_decouple
self.rectify = rectify self.rectify = rectify
self.fixed_decay = fixed_decay self.fixed_decay = fixed_decay
if self.weight_decouple:
print('Weight decoupling enabled in AdaBelief')
if self.fixed_decay:
print('Weight decay fixed')
if self.rectify:
print('Rectification enabled in AdaBelief')
if amsgrad:
print('AMSGrad enabled in AdaBelief')
def __setstate__(self, state): def __setstate__(self, state):
super(AdaBelief, self).__setstate__(state) super(AdaBelief, self).__setstate__(state)
@ -113,17 +81,13 @@ class AdaBelief(Optimizer):
# State initialization # State initialization
state['step'] = 0 state['step'] = 0
# Exponential moving average of gradient values # Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \ state['exp_avg'] = torch.zeros_like(p.data)
if version_higher else torch.zeros_like(p.data)
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state['exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \ state['exp_avg_var'] = torch.zeros_like(p.data)
if version_higher else torch.zeros_like(p.data)
if amsgrad: if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values # Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \ state['max_exp_avg_var'] = torch.zeros_like(p.data)
if version_higher else torch.zeros_like(p.data)
def step(self, closure=None): def step(self, closure=None):
"""Performs a single optimization step. """Performs a single optimization step.
@ -161,15 +125,12 @@ class AdaBelief(Optimizer):
if len(state) == 0: if len(state) == 0:
state['step'] = 0 state['step'] = 0
# Exponential moving average of gradient values # Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \ state['exp_avg'] = torch.zeros_like(p.data)
if version_higher else torch.zeros_like(p.data)
# Exponential moving average of squared gradient values # Exponential moving average of squared gradient values
state['exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \ state['exp_avg_var'] = torch.zeros_like(p.data)
if version_higher else torch.zeros_like(p.data)
if amsgrad: if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values # Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \ state['max_exp_avg_var'] = torch.zeros_like(p.data)
if version_higher else torch.zeros_like(p.data)
# perform weight decay, check if decoupled weight decay # perform weight decay, check if decoupled weight decay
if self.weight_decouple: if self.weight_decouple:

@ -121,7 +121,7 @@ def create_optimizer_v2(
elif opt_lower == 'adam': elif opt_lower == 'adam':
optimizer = optim.Adam(parameters, **opt_args) optimizer = optim.Adam(parameters, **opt_args)
elif opt_lower == 'adabelief': elif opt_lower == 'adabelief':
optimizer = AdaBelief(parameters, rectify = False, print_change_log = False,**opt_args) optimizer = AdaBelief(parameters, rectify=False, **opt_args)
elif opt_lower == 'adamw': elif opt_lower == 'adamw':
optimizer = optim.AdamW(parameters, **opt_args) optimizer = optim.AdamW(parameters, **opt_args)
elif opt_lower == 'nadam': elif opt_lower == 'nadam':

Loading…
Cancel
Save