Add lamb/lars to optim init imports, remove stray comment

pull/816/head
Ross Wightman 3 years ago
parent c207e02782
commit a16a753852

timm/optim/__init__.py
@@ -1,7 +1,10 @@
-from .adamp import AdamP
-from .adamw import AdamW
+from .adabelief import AdaBelief
 from .adafactor import Adafactor
 from .adahessian import Adahessian
+from .adamp import AdamP
+from .adamw import AdamW
+from .lamb import Lamb
+from .lars import Lars
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
 from .nadam import Nadam
@@ -9,5 +12,4 @@ from .nvnovograd import NvNovoGrad
 from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
-from .adabelief import AdaBelief
 from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
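With Lamb and Lars exported from the package __init__, both can be used straight from timm.optim. A minimal sketch; only the class names and module paths come from this diff, the keyword arguments and values below are illustrative assumptions:

import torch
from timm.optim import Lamb, Lars

model = torch.nn.Linear(16, 4)
# Assumed hyperparameters for illustration; see lamb.py / lars.py for the actual signatures.
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)
# optimizer = Lars(model.parameters(), lr=1.0, momentum=0.9)

out = model(torch.randn(2, 16))
out.sum().backward()
optimizer.step()

The same exports sit next to create_optimizer / create_optimizer_v2 in optim_factory, which is presumably where the 'lamb' / 'lars' string names are wired up for the training scripts.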

timm/optim/lars.py
@@ -87,7 +87,6 @@ class Lars(Optimizer):
         device = self.param_groups[0]['params'][0].device
         one_tensor = torch.tensor(1.0, device=device)  # because torch.where doesn't handle scalars correctly
 
-        # exclude scaling for params with 0 weight decay
         for group in self.param_groups:
             weight_decay = group['weight_decay']
             momentum = group['momentum']
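For context on the one_tensor line kept above: torch.where has historically required tensor arguments for both branches, so a 1.0 tensor is built once on the right device instead of passing a Python scalar. A rough sketch of the LARS-style trust-ratio pattern this feeds into, not the exact timm code (function name, defaults, and structure here are assumptions):

import torch

def scaled_update(param, grad, weight_decay, trust_coeff=0.001, eps=1e-8):
    # Pre-built tensor so torch.where has a tensor fallback on the correct device.
    one_tensor = torch.tensor(1.0, device=param.device)
    w_norm = param.norm(2.0)
    g_norm = grad.norm(2.0)
    trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
    # Fall back to a scale of 1.0 whenever either norm is zero.
    trust_ratio = torch.where(w_norm > 0, torch.where(g_norm > 0, trust_ratio, one_tensor), one_tensor)
    return grad.add(param, alpha=weight_decay).mul(trust_ratio)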
