Experimenting with per-epoch learning rate noise w/ step scheduler

pull/94/head
Ross Wightman 5 years ago
parent d77f45a6f6
commit 514b0938c4

@ -33,11 +33,18 @@ def create_scheduler(args, optimizer):
)
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
elif args.sched == 'step':
if isinstance(args.lr_noise, (list, tuple)):
noise_range = [n * num_epochs for n in args.lr_noise]
else:
noise_range = args.lr_noise * num_epochs
print(noise_range)
lr_scheduler = StepLRScheduler(
optimizer,
decay_t=args.decay_epochs,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
noise_range_t=noise_range,
noise_std=args.lr_noise_std,
)
return lr_scheduler, num_epochs

@ -14,14 +14,19 @@ class StepLRScheduler(Scheduler):
decay_rate: float = 1.,
warmup_t=0,
warmup_lr_init=0,
noise_range_t=None,
noise_std=1.0,
t_in_epochs=True,
initialize=True) -> None:
initialize=True,
) -> None:
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
self.decay_t = decay_t
self.decay_rate = decay_rate
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
self.noise_range_t = noise_range_t
self.noise_std = noise_std
self.t_in_epochs = t_in_epochs
if self.warmup_t:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
@ -33,8 +38,18 @@ class StepLRScheduler(Scheduler):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
lrs = [v * (self.decay_rate ** (t // self.decay_t))
for v in self.base_values]
lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
if self.noise_range_t is not None:
if isinstance(self.noise_range_t, (list, tuple)):
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
else:
apply_noise = t >= self.noise_range_t
if apply_noise:
g = torch.Generator()
g.manual_seed(t)
lr_mult = torch.randn(1, generator=g).item() * self.noise_std + 1.
lrs = [min(5 * v, max(v / 5, v * lr_mult)) for v in lrs]
print(lrs)
return lrs
def get_epoch_values(self, epoch: int):

@ -105,6 +105,10 @@ parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step"')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate nose std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',

Loading…
Cancel
Save