From cd3dc4979f6ca16a09910b4a32b7a8f07cc31fda Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 12 Apr 2021 08:25:31 -0700
Subject: [PATCH] Fix adabelief imports, remove prints, preserve memory format
 is the default arg for zeros_like
---
 timm/optim/adabelief.py     | 65 ++++++++--------------------------------
 timm/optim/optim_factory.py |  2 +-
 2 files changed, 14 insertions(+), 53 deletions(-)

diff --git a/timm/optim/adabelief.py b/timm/optim/adabelief.py
index 71075524..a26d7b27 100644
--- a/timm/optim/adabelief.py
+++ b/timm/optim/adabelief.py
@@ -1,13 +1,11 @@
 import math
 import torch
 from torch.optim.optimizer import Optimizer
-from tabulate import tabulate
-from colorama import Fore, Back, Style
 
-version_higher = ( torch.__version__ >= "1.5.0" )
 
 class AdaBelief(Optimizer):
     r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch
+
     Arguments:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups
@@ -33,39 +31,17 @@ class AdaBelief(Optimizer):
             update similar to RAdam
         degenerated_to_sgd (boolean, optional) (default:True) If set as True, then
             perform SGD update when variance of gradient is high
-        print_change_log (boolean, optional) (default: True) If set as True, print the modifcation to
-            default hyper-parameters
     reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020
+
+    For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer'
+    For example train/args for EfficientNet see these gists
+      - link to train_scipt: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037
+      - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
     """
 
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16,
                  weight_decay=0, amsgrad=False, weight_decouple=True, fixed_decay=False, rectify=True,
-                 degenerated_to_sgd=True, print_change_log = True):
-
-        # ------------------------------------------------------------------------------
-        # Print modifications to default arguments
-        if print_change_log:
-            print(Fore.RED + 'Please check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.')
-            print(Fore.RED + 'Modifications to default arguments:')
-            default_table = tabulate([
-                ['adabelief-pytorch=0.0.5','1e-8','False','False'],
-                ['>=0.1.0 (Current 0.2.0)','1e-16','True','True']],
-                headers=['eps','weight_decouple','rectify'])
-            print(Fore.RED + default_table)
-
-            recommend_table = tabulate([
-                ['Recommended eps = 1e-8', 'Recommended eps = 1e-16'],
-                ],
-                headers=['SGD better than Adam (e.g. CNN for Image Classification)','Adam better than SGD (e.g. Transformer, GAN)'])
-            print(Fore.BLUE + recommend_table)
-
-            print(Fore.BLUE +'For a complete table of recommended hyperparameters, see')
-            print(Fore.BLUE + 'https://github.com/juntang-zhuang/Adabelief-Optimizer')
-
-            print(Fore.GREEN + 'You can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.')
-
-            print(Style.RESET_ALL)
-        # ------------------------------------------------------------------------------
+                 degenerated_to_sgd=True):
 
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
@@ -90,14 +66,6 @@ class AdaBelief(Optimizer):
         self.weight_decouple = weight_decouple
         self.rectify = rectify
         self.fixed_decay = fixed_decay
-        if self.weight_decouple:
-            print('Weight decoupling enabled in AdaBelief')
-            if self.fixed_decay:
-                print('Weight decay fixed')
-        if self.rectify:
-            print('Rectification enabled in AdaBelief')
-        if amsgrad:
-            print('AMSGrad enabled in AdaBelief')
 
     def __setstate__(self, state):
         super(AdaBelief, self).__setstate__(state)
@@ -113,17 +81,13 @@ class AdaBelief(Optimizer):
                 # State initialization
                 state['step'] = 0
                 # Exponential moving average of gradient values
-                state['exp_avg'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                    if version_higher else torch.zeros_like(p.data)
+                state['exp_avg'] = torch.zeros_like(p.data)
 
                 # Exponential moving average of squared gradient values
-                state['exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                    if version_higher else torch.zeros_like(p.data)
-
+                state['exp_avg_var'] = torch.zeros_like(p.data)
                 if amsgrad:
                     # Maintains max of all exp. moving avg. of sq. grad. values
-                    state['max_exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                        if version_higher else torch.zeros_like(p.data)
+                    state['max_exp_avg_var'] = torch.zeros_like(p.data)
 
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -161,15 +125,12 @@ class AdaBelief(Optimizer):
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                        if version_higher else torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p.data)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                        if version_higher else torch.zeros_like(p.data)
+                    state['exp_avg_var'] = torch.zeros_like(p.data)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_var'] = torch.zeros_like(p.data,memory_format=torch.preserve_format) \
-                            if version_higher else torch.zeros_like(p.data)
+                        state['max_exp_avg_var'] = torch.zeros_like(p.data)
 
                 # perform weight decay, check if decoupled weight decay
                 if self.weight_decouple:
diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py
index c9b3b6df..2017d21f 100644
--- a/timm/optim/optim_factory.py
+++ b/timm/optim/optim_factory.py
@@ -121,7 +121,7 @@ def create_optimizer_v2(
     elif opt_lower == 'adam':
         optimizer = optim.Adam(parameters, **opt_args)
     elif opt_lower == 'adabelief':
-        optimizer = AdaBelief(parameters, rectify = False, print_change_log = False,**opt_args)
+        optimizer = AdaBelief(parameters, rectify=False, **opt_args)
    elif opt_lower == 'adamw':
         optimizer = optim.AdamW(parameters, **opt_args)
     elif opt_lower == 'nadam':
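
A quick check of the rationale in the commit subject (a minimal sketch, not part of the patch itself; tensor shapes are illustrative): since PyTorch 1.5, which is exactly the floor the removed version_higher check was guarding, torch.zeros_like already defaults to memory_format=torch.preserve_format, so the explicit argument and the version branch were redundant.

    import torch

    # zeros_like keeps the input's memory format by default (preserve_format),
    # so a channels_last parameter gets a channels_last state buffer without an
    # explicit memory_format argument or a torch version check.
    x = torch.randn(2, 3, 4, 4).to(memory_format=torch.channels_last)
    z = torch.zeros_like(x)
    assert z.is_contiguous(memory_format=torch.channels_last)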
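
And a minimal usage sketch of the simplified constructor, assuming only the signature shown in the patch; the model and hyperparameter values are illustrative, not tuned recommendations.

    import torch
    from timm.optim.adabelief import AdaBelief

    # After this change AdaBelief constructs silently: there is no
    # print_change_log argument and no tabulate/colorama dependency.
    model = torch.nn.Linear(10, 2)
    optimizer = AdaBelief(model.parameters(), lr=1e-3, eps=1e-16, rectify=False)

    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()

Via the factory, the 'adabelief' branch of create_optimizer_v2 shown above now forwards only rectify=False plus the usual opt_args.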