From a16a7538529e8f0e196257a708852ab9ea6ff997 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 18 Aug 2021 22:55:02 -0700 Subject: [PATCH] Add lamb/lars to optim init imports, remove stray comment --- timm/optim/__init__.py | 8 +++++--- timm/optim/lars.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py index 2df9fb12..7ee4958e 100644 --- a/timm/optim/__init__.py +++ b/timm/optim/__init__.py @@ -1,7 +1,10 @@ -from .adamp import AdamP -from .adamw import AdamW +from .adabelief import AdaBelief from .adafactor import Adafactor from .adahessian import Adahessian +from .adamp import AdamP +from .adamw import AdamW +from .lamb import Lamb +from .lars import Lars from .lookahead import Lookahead from .madgrad import MADGRAD from .nadam import Nadam @@ -9,5 +12,4 @@ from .nvnovograd import NvNovoGrad from .radam import RAdam from .rmsprop_tf import RMSpropTF from .sgdp import SGDP -from .adabelief import AdaBelief from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs diff --git a/timm/optim/lars.py b/timm/optim/lars.py index 958c2d0e..98198e67 100644 --- a/timm/optim/lars.py +++ b/timm/optim/lars.py @@ -87,7 +87,6 @@ class Lars(Optimizer): device = self.param_groups[0]['params'][0].device one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly - # exclude scaling for params with 0 weight decay for group in self.param_groups: weight_decay = group['weight_decay'] momentum = group['momentum']