diff --git a/timm/scheduler/multistep_lr.py b/timm/scheduler/multistep_lr.py new file mode 100644 index 00000000..a5d5fe19 --- /dev/null +++ b/timm/scheduler/multistep_lr.py @@ -0,0 +1,65 @@ +""" MultiStep LR Scheduler + +Basic multi step LR schedule with warmup, noise. +""" +import torch +import bisect +from timm.scheduler.scheduler import Scheduler +from typing import List + +class MultiStepLRScheduler(Scheduler): + """ + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + decay_t: List[int], + decay_rate: float = 1., + warmup_t=0, + warmup_lr_init=0, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + self.decay_t = decay_t + self.decay_rate = decay_rate + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def get_curr_decay_steps(self, t): + # find where in the array t goes, + # assumes self.decay_t is sorted + return bisect.bisect_right(self.decay_t, t+1) + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index 9f7748f4..5f5a4ca8 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -5,6 +5,7 @@ from .cosine_lr import CosineLRScheduler from .tanh_lr import TanhLRScheduler from .step_lr import StepLRScheduler from .plateau_lr import PlateauLRScheduler +from .multistep_lr import MultiStepLRScheduler def create_scheduler(args, optimizer): @@ -67,6 +68,18 @@ def create_scheduler(args, optimizer): noise_std=getattr(args, 'lr_noise_std', 1.), noise_seed=getattr(args, 'seed', 42), ) + elif args.sched == 'multistep': + lr_scheduler = MultiStepLRScheduler( + optimizer, + decay_t=args.decay_epochs, + decay_rate=args.decay_rate, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) elif args.sched == 'plateau': mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' lr_scheduler = PlateauLRScheduler(