Revamp LR noise, move logic to scheduler base. Fixup PlateauLRScheduler and add it as an option.

pull/94/head
Ross Wightman 5 years ago
parent 514b0938c4
commit 27b3680d49

@ -29,8 +29,15 @@ class CosineLRScheduler(Scheduler):
warmup_prefix=False,
cycle_limit=0,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
initialize=True) -> None:
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
assert t_initial > 0
assert lr_min >= 0

@ -8,33 +8,34 @@ class PlateauLRScheduler(Scheduler):
def __init__(self,
optimizer,
factor=0.1,
patience=10,
verbose=False,
decay_rate=0.1,
patience_t=10,
verbose=True,
threshold=1e-4,
cooldown_epochs=0,
warmup_updates=0,
cooldown_t=0,
warmup_t=0,
warmup_lr_init=0,
lr_min=0,
mode='min',
initialize=True,
):
super().__init__(optimizer, 'lr', initialize=False)
super().__init__(optimizer, 'lr', initialize=initialize)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer.optimizer,
patience=patience,
factor=factor,
self.optimizer,
patience=patience_t,
factor=decay_rate,
verbose=verbose,
threshold=threshold,
cooldown=cooldown_epochs,
cooldown=cooldown_t,
mode=mode,
min_lr=lr_min
)
self.warmup_updates = warmup_updates
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
if self.warmup_updates:
self.warmup_active = warmup_updates > 0 # this state updates with num_updates
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
@ -51,18 +52,9 @@ class PlateauLRScheduler(Scheduler):
self.lr_scheduler.last_epoch = state_dict['last_epoch']
# override the base class step fn completely
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
if val_loss is not None and not self.warmup_active:
self.lr_scheduler.step(val_loss, epoch)
else:
self.lr_scheduler.last_epoch = epoch
def get_update_values(self, num_updates: int):
if num_updates < self.warmup_updates:
lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
def step(self, epoch, metric=None):
if epoch <= self.warmup_t:
lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
super().update_groups(lrs)
else:
self.warmup_active = False # warmup cancelled by first update past warmup_update count
lrs = None # no change on update after warmup stage
return lrs
self.lr_scheduler.step(metric, epoch)

@ -25,6 +25,11 @@ class Scheduler:
def __init__(self,
optimizer: torch.optim.Optimizer,
param_group_field: str,
noise_range_t=None,
noise_type='normal',
noise_pct=0.67,
noise_std=1.0,
noise_seed=None,
initialize: bool = True) -> None:
self.optimizer = optimizer
self.param_group_field = param_group_field
@ -40,6 +45,11 @@ class Scheduler:
raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
self.metric = None # any point to having this for all?
self.noise_range_t = noise_range_t
self.noise_pct = noise_pct
self.noise_type = noise_type
self.noise_std = noise_std
self.noise_seed = noise_seed if noise_seed is not None else 42
self.update_groups(self.base_values)
def state_dict(self) -> Dict[str, Any]:
@ -58,12 +68,14 @@ class Scheduler:
self.metric = metric
values = self.get_epoch_values(epoch)
if values is not None:
values = self._add_noise(values, epoch)
self.update_groups(values)
def step_update(self, num_updates: int, metric: float = None):
self.metric = metric
values = self.get_update_values(num_updates)
if values is not None:
values = self._add_noise(values, num_updates)
self.update_groups(values)
def update_groups(self, values):
@ -71,3 +83,23 @@ class Scheduler:
values = [values] * len(self.optimizer.param_groups)
for param_group, value in zip(self.optimizer.param_groups, values):
param_group[self.param_group_field] = value
def _add_noise(self, lrs, t):
if self.noise_range_t is not None:
if isinstance(self.noise_range_t, (list, tuple)):
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
else:
apply_noise = t >= self.noise_range_t
if apply_noise:
g = torch.Generator()
g.manual_seed(self.noise_seed + t)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
break
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
lrs = [v + v * noise for v in lrs]
return lrs

@ -1,10 +1,21 @@
from .cosine_lr import CosineLRScheduler
from .tanh_lr import TanhLRScheduler
from .step_lr import StepLRScheduler
from .plateau_lr import PlateauLRScheduler
def create_scheduler(args, optimizer):
num_epochs = args.epochs
if args.lr_noise is not None:
if isinstance(args.lr_noise, (list, tuple)):
noise_range = [n * num_epochs for n in args.lr_noise]
else:
noise_range = args.lr_noise * num_epochs
print('Noise range:', noise_range)
else:
noise_range = None
lr_scheduler = None
#FIXME expose cycle parms of the scheduler config to arguments
if args.sched == 'cosine':
@ -18,6 +29,10 @@ def create_scheduler(args, optimizer):
warmup_t=args.warmup_epochs,
cycle_limit=1,
t_in_epochs=True,
noise_range_t=noise_range,
noise_pct=args.lr_noise_pct,
noise_std=args.lr_noise_std,
noise_seed=args.seed,
)
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
elif args.sched == 'tanh':
@ -30,14 +45,13 @@ def create_scheduler(args, optimizer):
warmup_t=args.warmup_epochs,
cycle_limit=1,
t_in_epochs=True,
noise_range_t=noise_range,
noise_pct=args.lr_noise_pct,
noise_std=args.lr_noise_std,
noise_seed=args.seed,
)
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
elif args.sched == 'step':
if isinstance(args.lr_noise, (list, tuple)):
noise_range = [n * num_epochs for n in args.lr_noise]
else:
noise_range = args.lr_noise * num_epochs
print(noise_range)
lr_scheduler = StepLRScheduler(
optimizer,
decay_t=args.decay_epochs,
@ -45,6 +59,19 @@ def create_scheduler(args, optimizer):
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
noise_range_t=noise_range,
noise_pct=args.lr_noise_pct,
noise_std=args.lr_noise_std,
noise_seed=args.seed,
)
elif args.sched == 'plateau':
lr_scheduler = PlateauLRScheduler(
optimizer,
decay_rate=args.decay_rate,
patience_t=args.patience_epochs,
lr_min=args.min_lr,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cooldown_t=args.cooldown_epochs,
)
return lr_scheduler, num_epochs

@ -10,23 +10,26 @@ class StepLRScheduler(Scheduler):
def __init__(self,
optimizer: torch.optim.Optimizer,
decay_t: int,
decay_t: float,
decay_rate: float = 1.,
warmup_t=0,
warmup_lr_init=0,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
t_in_epochs=True,
noise_seed=42,
initialize=True,
) -> None:
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
self.decay_t = decay_t
self.decay_rate = decay_rate
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.noise_range_t = noise_range_t
self.noise_std = noise_std
self.t_in_epochs = t_in_epochs
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
@ -39,17 +42,6 @@ class StepLRScheduler(Scheduler):
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
if self.noise_range_t is not None:
if isinstance(self.noise_range_t, (list, tuple)):
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
else:
apply_noise = t >= self.noise_range_t
if apply_noise:
g = torch.Generator()
g.manual_seed(t)
lr_mult = torch.randn(1, generator=g).item() * self.noise_std + 1.
lrs = [min(5 * v, max(v / 5, v * lr_mult)) for v in lrs]
print(lrs)
return lrs
def get_epoch_values(self, epoch: int):

@ -28,8 +28,15 @@ class TanhLRScheduler(Scheduler):
warmup_prefix=False,
cycle_limit=0,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
initialize=True) -> None:
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
initialize=initialize)
assert t_initial > 0
assert lr_min >= 0

@ -107,8 +107,10 @@ parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate nose std-dev (default: 1.0)')
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
@ -123,6 +125,8 @@ parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Augmentation parameters

Loading…
Cancel
Save