import torch

from .scheduler import Scheduler


class PlateauLRScheduler(Scheduler):
    """Decay the LR by a factor every time the validation loss plateaus.

    Thin wrapper around ``torch.optim.lr_scheduler.ReduceLROnPlateau`` that can
    optionally prepend a linear warmup: for ``warmup_t`` epochs the LR of each
    param group is interpolated from ``warmup_lr_init`` up to that group's base
    value, after which the plateau scheduler takes over.
    """

    def __init__(self,
                 optimizer,
                 decay_rate=0.1,
                 patience_t=10,
                 verbose=True,
                 threshold=1e-4,
                 cooldown_t=0,
                 warmup_t=0,
                 warmup_lr_init=0,
                 lr_min=0,
                 mode='min',
                 initialize=True,
                 ):
        """
        Args:
            optimizer: optimizer whose param groups' ``'lr'`` is scheduled.
            decay_rate: forwarded to ReduceLROnPlateau as ``factor``.
            patience_t: forwarded to ReduceLROnPlateau as ``patience``.
            verbose: forwarded to ReduceLROnPlateau as ``verbose``.
            threshold: forwarded to ReduceLROnPlateau as ``threshold``.
            cooldown_t: forwarded to ReduceLROnPlateau as ``cooldown``.
            warmup_t: number of warmup epochs; 0 disables warmup.
            warmup_lr_init: LR applied at construction (when warmup is enabled)
                and at epoch 0 of warmup.
            lr_min: forwarded to ReduceLROnPlateau as ``min_lr``.
            mode: forwarded to ReduceLROnPlateau as ``mode``.
            initialize: forwarded to the base ``Scheduler``.
        """
        super().__init__(optimizer, 'lr', initialize=initialize)

        # NOTE(review): recent torch releases deprecate the ``verbose`` argument
        # of ReduceLROnPlateau — confirm against the installed torch version.
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            patience=patience_t,
            factor=decay_rate,
            verbose=verbose,
            threshold=threshold,
            cooldown=cooldown_t,
            mode=mode,
            min_lr=lr_min
        )

        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        if self.warmup_t:
            # Per-epoch increment so that epoch == warmup_t lands exactly on the
            # base value (presumably each group's initial LR, set up by the base
            # Scheduler — confirm).
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            # Kept for attribute compatibility; step() never reads it when
            # warmup is disabled.
            self.warmup_steps = [1 for _ in self.base_values]

    def state_dict(self):
        """Return the plateau-tracking state needed to resume scheduling.

        Besides ``best`` and ``last_epoch``, the patience and cooldown counters
        are saved so a resumed run continues mid-plateau instead of restarting
        its patience window.
        """
        return {
            'best': self.lr_scheduler.best,
            'last_epoch': self.lr_scheduler.last_epoch,
            'num_bad_epochs': self.lr_scheduler.num_bad_epochs,
            'cooldown_counter': self.lr_scheduler.cooldown_counter,
        }

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`.

        Only ``'best'`` is required; the other keys are optional so checkpoints
        written by older versions (which lacked them) still load.
        """
        self.lr_scheduler.best = state_dict['best']
        for key in ('last_epoch', 'num_bad_epochs', 'cooldown_counter'):
            if key in state_dict:
                setattr(self.lr_scheduler, key, state_dict[key])

    # override the base class step fn completely
    def step(self, epoch, metric=None):
        """Advance the schedule to ``epoch``.

        Args:
            epoch: current epoch index.
            metric: monitored value handed to ReduceLROnPlateau. Ignored during
                warmup, but required afterwards (ReduceLROnPlateau converts it
                with ``float()``).
        """
        # Guard on warmup_t: without it, warmup_t == 0 and epoch == 0 would take
        # the warmup branch and set every LR to warmup_lr_init (0 by default).
        if self.warmup_t and epoch <= self.warmup_t:
            lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
            super().update_groups(lrs)
        else:
            # Passing ``epoch`` to ReduceLROnPlateau.step is deprecated (it warns
            # and only sets last_epoch); pre-set the counter so the implicit
            # increment inside step() yields last_epoch == epoch as before.
            self.lr_scheduler.last_epoch = epoch - 1
            self.lr_scheduler.step(metric)