diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py index 4f2cacb6..fbfc531f 100644 --- a/timm/scheduler/plateau_lr.py +++ b/timm/scheduler/plateau_lr.py @@ -43,7 +43,7 @@ class PlateauLRScheduler(Scheduler): min_lr=lr_min ) - self.noise_range = noise_range_t + self.noise_range_t = noise_range_t self.noise_pct = noise_pct self.noise_type = noise_type self.noise_std = noise_std @@ -82,25 +82,12 @@ class PlateauLRScheduler(Scheduler): self.lr_scheduler.step(metric, epoch) # step the base scheduler - if self.noise_range is not None: - if isinstance(self.noise_range, (list, tuple)): - apply_noise = self.noise_range[0] <= epoch < self.noise_range[1] - else: - apply_noise = epoch >= self.noise_range - if apply_noise: - self._apply_noise(epoch) + if self._is_apply_noise(epoch): + self._apply_noise(epoch) + def _apply_noise(self, epoch): - g = torch.Generator() - g.manual_seed(self.noise_seed + epoch) - if self.noise_type == 'normal': - while True: - # resample if noise out of percent limit, brute force but shouldn't spin much - noise = torch.randn(1, generator=g).item() - if abs(noise) < self.noise_pct: - break - else: - noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + noise = self._calculate_noise(epoch) # apply the noise on top of previous LR, cache the old value so we can restore for normal # stepping of base scheduler diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index 21d51509..81af76f9 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -85,21 +85,29 @@ class Scheduler: param_group[self.param_group_field] = value def _add_noise(self, lrs, t): + if self._is_apply_noise(t): + noise = self._calculate_noise(t) + lrs = [v + v * noise for v in lrs] + return lrs + + def _is_apply_noise(self, t) -> bool: + """Return True if scheduler in noise range.""" if self.noise_range_t is not None: if isinstance(self.noise_range_t, (list, tuple)): apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] else: apply_noise = t >= self.noise_range_t - if apply_noise: - g = torch.Generator() - g.manual_seed(self.noise_seed + t) - if self.noise_type == 'normal': - while True: + return apply_noise + + def _calculate_noise(self, t) -> float: + g = torch.Generator() + g.manual_seed(self.noise_seed + t) + if self.noise_type == 'normal': + while True: # resample if noise out of percent limit, brute force but shouldn't spin much - noise = torch.randn(1, generator=g).item() - if abs(noise) < self.noise_pct: - break - else: - noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct - lrs = [v + v * noise for v in lrs] - return lrs + noise = torch.randn(1, generator=g).item() + if abs(noise) < self.noise_pct: + return noise + else: + noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + return noise