Add a cycle limit to the tanh LR scheduler

pull/1/head
Ross Wightman 6 years ago
parent 824f42e75e
commit 48360625f2

@ -25,6 +25,7 @@ class TanhLRScheduler(Scheduler):
decay_rate: float = 1., decay_rate: float = 1.,
warmup_updates=0, warmup_updates=0,
warmup_lr_init=0, warmup_lr_init=0,
cycle_limit=0,
initialize=True) -> None: initialize=True) -> None:
super().__init__(optimizer, param_group_field="lr", initialize=initialize) super().__init__(optimizer, param_group_field="lr", initialize=initialize)
@ -36,7 +37,7 @@ class TanhLRScheduler(Scheduler):
self.t_mul = t_mul self.t_mul = t_mul
self.lr_min = lr_min self.lr_min = lr_min
self.decay_rate = decay_rate self.decay_rate = decay_rate
self.cycle_limit = 0 self.cycle_limit = cycle_limit
self.warmup_updates = warmup_updates self.warmup_updates = warmup_updates
self.warmup_lr_init = warmup_lr_init self.warmup_lr_init = warmup_lr_init
if self.warmup_updates: if self.warmup_updates:
@ -65,7 +66,7 @@ class TanhLRScheduler(Scheduler):
t_i = self.t_initial t_i = self.t_initial
t_curr = curr_updates - (self.t_initial * i) t_curr = curr_updates - (self.t_initial * i)
if self.cycle_limit == 0 or i <= self.cycle_limit: if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
gamma = self.decay_rate ** i gamma = self.decay_rate ** i
lr_min = self.lr_min * gamma lr_min = self.lr_min * gamma
lr_max_values = [v * gamma for v in self.base_values] lr_max_values = [v * gamma for v in self.base_values]

Loading…
Cancel
Save