From 48360625f28fb3fb375b85a28a6a78e44a37bfd4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 8 Feb 2019 21:47:27 -0800 Subject: [PATCH] Cycle limit on tanh sched --- scheduler/tanh_lr.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scheduler/tanh_lr.py b/scheduler/tanh_lr.py index a8a67777..fbb6ccf4 100644 --- a/scheduler/tanh_lr.py +++ b/scheduler/tanh_lr.py @@ -25,6 +25,7 @@ class TanhLRScheduler(Scheduler): decay_rate: float = 1., warmup_updates=0, warmup_lr_init=0, + cycle_limit=0, initialize=True) -> None: super().__init__(optimizer, param_group_field="lr", initialize=initialize) @@ -36,7 +37,7 @@ class TanhLRScheduler(Scheduler): self.t_mul = t_mul self.lr_min = lr_min self.decay_rate = decay_rate - self.cycle_limit = 0 + self.cycle_limit = cycle_limit self.warmup_updates = warmup_updates self.warmup_lr_init = warmup_lr_init if self.warmup_updates: @@ -65,7 +66,7 @@ class TanhLRScheduler(Scheduler): t_i = self.t_initial t_curr = curr_updates - (self.t_initial * i) - if self.cycle_limit == 0 or i <= self.cycle_limit: + if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): gamma = self.decay_rate ** i lr_min = self.lr_min * gamma lr_max_values = [v * gamma for v in self.base_values]