diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py
index 80d37b37..dca8a580 100644
--- a/timm/scheduler/scheduler_factory.py
+++ b/timm/scheduler/scheduler_factory.py
@@ -33,11 +33,23 @@ def create_scheduler(args, optimizer):
         )
         num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
     elif args.sched == 'step':
+        if args.lr_noise is not None:
+            # convert noise on/off epoch percentages to absolute epoch counts
+            if isinstance(args.lr_noise, (list, tuple)):
+                noise_range = [n * num_epochs for n in args.lr_noise]
+                if len(noise_range) == 1:
+                    noise_range = noise_range[0]
+            else:
+                noise_range = args.lr_noise * num_epochs
+        else:
+            noise_range = None
         lr_scheduler = StepLRScheduler(
             optimizer,
             decay_t=args.decay_epochs,
             decay_rate=args.decay_rate,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
+            noise_range_t=noise_range,
+            noise_std=args.lr_noise_std,
         )
     return lr_scheduler, num_epochs
diff --git a/timm/scheduler/step_lr.py b/timm/scheduler/step_lr.py
index 5ee8b90f..d3060fd8 100644
--- a/timm/scheduler/step_lr.py
+++ b/timm/scheduler/step_lr.py
@@ -14,14 +14,19 @@ class StepLRScheduler(Scheduler):
                  decay_rate: float = 1.,
                  warmup_t=0,
                  warmup_lr_init=0,
+                 noise_range_t=None,
+                 noise_std=1.0,
                  t_in_epochs=True,
-                 initialize=True) -> None:
+                 initialize=True,
+                 ) -> None:
         super().__init__(optimizer, param_group_field="lr", initialize=initialize)
 
         self.decay_t = decay_t
         self.decay_rate = decay_rate
         self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
+        self.noise_range_t = noise_range_t
+        self.noise_std = noise_std
         self.t_in_epochs = t_in_epochs
         if self.warmup_t:
             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
@@ -33,8 +38,19 @@ class StepLRScheduler(Scheduler):
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
-            lrs = [v * (self.decay_rate ** (t // self.decay_t))
-                   for v in self.base_values]
+            lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
+            if self.noise_range_t is not None:
+                if isinstance(self.noise_range_t, (list, tuple)):
+                    apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
+                else:
+                    apply_noise = t >= self.noise_range_t
+                if apply_noise:
+                    # seed with t so the noise for a given epoch is reproducible
+                    g = torch.Generator()
+                    g.manual_seed(t)
+                    lr_mult = torch.randn(1, generator=g).item() * self.noise_std + 1.
+                    # clamp the noised lr to [base / 5, base * 5]
+                    lrs = [min(5 * v, max(v / 5, v * lr_mult)) for v in lrs]
         return lrs
 
     def get_epoch_values(self, epoch: int):
diff --git a/train.py b/train.py
index 7b4e1af0..c9d73833 100755
--- a/train.py
+++ b/train.py
@@ -105,6 +105,10 @@ parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                     help='LR scheduler (default: "step"')
 parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                     help='learning rate (default: 0.01)')
+parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
+                    help='learning rate noise on/off epoch percentages')
+parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
+                    help='learning rate noise std-dev (default: 1.0)')
 parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                     help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
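
Note: as a standalone illustration of the noise scheme in the step_lr.py hunk
above (this helper is hypothetical, not part of the patch), the per-epoch
multiplier can be reproduced as follows: a generator seeded with the epoch
index draws one normal sample, which is scaled by the std-dev, shifted to
center on 1.0, and the result is clamped to [lr / 5, 5 * lr]:

    import torch

    def noised_lr(base_lr, t, noise_range, noise_std):
        # hypothetical mirror of the StepLRScheduler.get_lr noise logic
        if isinstance(noise_range, (list, tuple)):
            apply_noise = noise_range[0] <= t < noise_range[1]
        else:
            apply_noise = t >= noise_range
        if not apply_noise:
            return base_lr
        g = torch.Generator()
        g.manual_seed(t)  # same epoch always draws the same noise
        lr_mult = torch.randn(1, generator=g).item() * noise_std + 1.
        return min(5 * base_lr, max(base_lr / 5, base_lr * lr_mult))

    # print the (possibly noised) lr at a few epochs, noise active in [40, 90)
    for t in range(0, 100, 10):
        print(t, noised_lr(0.01, t, (40, 90), 0.05))

Seeding on t keeps the schedule deterministic across restarts and identical
across distributed workers, at the cost of repeating the same noise pattern
for every run that visits the same epoch indices.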
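
Example invocation (assuming an otherwise valid train.py command line, with
--lr-noise taking one or two percentages of the total epoch count):

    python train.py /path/to/data --sched step --lr 0.01 --epochs 100 \
        --lr-noise 0.4 0.9 --lr-noise-std 0.05

This applies noise between epochs 40 and 90, while a single value such as
--lr-noise 0.6 turns noise on from epoch 60 onward.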