Merge pull request #1112 from ayasyrev/sched_noise_dup_code

sched noise dup code remove
pull/1014/head
Ross Wightman 3 years ago committed by GitHub
commit d757fecaac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -43,7 +43,7 @@ class PlateauLRScheduler(Scheduler):
min_lr=lr_min min_lr=lr_min
) )
self.noise_range = noise_range_t self.noise_range_t = noise_range_t
self.noise_pct = noise_pct self.noise_pct = noise_pct
self.noise_type = noise_type self.noise_type = noise_type
self.noise_std = noise_std self.noise_std = noise_std
@ -82,25 +82,12 @@ class PlateauLRScheduler(Scheduler):
self.lr_scheduler.step(metric, epoch) # step the base scheduler self.lr_scheduler.step(metric, epoch) # step the base scheduler
if self.noise_range is not None: if self._is_apply_noise(epoch):
if isinstance(self.noise_range, (list, tuple)): self._apply_noise(epoch)
apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
else:
apply_noise = epoch >= self.noise_range
if apply_noise:
self._apply_noise(epoch)
def _apply_noise(self, epoch): def _apply_noise(self, epoch):
g = torch.Generator() noise = self._calculate_noise(epoch)
g.manual_seed(self.noise_seed + epoch)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
break
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
# apply the noise on top of previous LR, cache the old value so we can restore for normal # apply the noise on top of previous LR, cache the old value so we can restore for normal
# stepping of base scheduler # stepping of base scheduler

@ -85,21 +85,29 @@ class Scheduler:
param_group[self.param_group_field] = value param_group[self.param_group_field] = value
def _add_noise(self, lrs, t): def _add_noise(self, lrs, t):
if self._is_apply_noise(t):
noise = self._calculate_noise(t)
lrs = [v + v * noise for v in lrs]
return lrs
def _is_apply_noise(self, t) -> bool:
"""Return True if scheduler in noise range."""
if self.noise_range_t is not None: if self.noise_range_t is not None:
if isinstance(self.noise_range_t, (list, tuple)): if isinstance(self.noise_range_t, (list, tuple)):
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
else: else:
apply_noise = t >= self.noise_range_t apply_noise = t >= self.noise_range_t
if apply_noise: return apply_noise
g = torch.Generator()
g.manual_seed(self.noise_seed + t) def _calculate_noise(self, t) -> float:
if self.noise_type == 'normal': g = torch.Generator()
while True: g.manual_seed(self.noise_seed + t)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much # resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item() noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct: if abs(noise) < self.noise_pct:
break return noise
else: else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
lrs = [v + v * noise for v in lrs] return noise
return lrs

Loading…
Cancel
Save