Fix #1184, scheduler noise bug during merge madness

pull/1014/head
Ross Wightman 3 years ago
parent 9440a50c95
commit 7cdd164d77

@ -92,6 +92,7 @@ class Scheduler:
def _is_apply_noise(self, t) -> bool: def _is_apply_noise(self, t) -> bool:
"""Return True if scheduler in noise range.""" """Return True if scheduler in noise range."""
apply_noise = False
if self.noise_range_t is not None: if self.noise_range_t is not None:
if isinstance(self.noise_range_t, (list, tuple)): if isinstance(self.noise_range_t, (list, tuple)):
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
@ -104,7 +105,7 @@ class Scheduler:
g.manual_seed(self.noise_seed + t) g.manual_seed(self.noise_seed + t)
if self.noise_type == 'normal': if self.noise_type == 'normal':
while True: while True:
# resample if noise out of percent limit, brute force but shouldn't spin much # resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item() noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct: if abs(noise) < self.noise_pct:
return noise return noise

Loading…
Cancel
Save