diff --git a/timm/scheduler/__init__.py b/timm/scheduler/__init__.py index 6a778982..f1961b88 100644 --- a/timm/scheduler/__init__.py +++ b/timm/scheduler/__init__.py @@ -1,5 +1,8 @@ from .cosine_lr import CosineLRScheduler +from .multistep_lr import MultiStepLRScheduler from .plateau_lr import PlateauLRScheduler +from .poly_lr import PolyLRScheduler from .step_lr import StepLRScheduler from .tanh_lr import TanhLRScheduler + from .scheduler_factory import create_scheduler diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py index 1532f092..84ee349e 100644 --- a/timm/scheduler/cosine_lr.py +++ b/timm/scheduler/cosine_lr.py @@ -1,8 +1,8 @@ """ Cosine Scheduler -Cosine LR schedule with warmup, cycle/restarts, noise. +Cosine LR schedule with warmup, cycle/restarts, noise, k-decay. -Hacked together by / Copyright 2020 Ross Wightman +Hacked together by / Copyright 2021 Ross Wightman """ import logging import math @@ -22,23 +22,26 @@ class CosineLRScheduler(Scheduler): Inspiration from https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py + + k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909 """ def __init__(self, optimizer: torch.optim.Optimizer, t_initial: int, - t_mul: float = 1., lr_min: float = 0., - decay_rate: float = 1., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, warmup_t=0, warmup_lr_init=0, warmup_prefix=False, - cycle_limit=0, t_in_epochs=True, noise_range_t=None, noise_pct=0.67, noise_std=1.0, noise_seed=42, + k_decay=1.0, initialize=True) -> None: super().__init__( optimizer, param_group_field="lr", @@ -47,18 +50,19 @@ class CosineLRScheduler(Scheduler): assert t_initial > 0 assert lr_min >= 0 - if t_initial == 1 and t_mul == 1 and decay_rate == 1: + if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1: _logger.warning("Cosine annealing scheduler will have no effect on the learning " "rate since t_initial = t_mul = eta_mul = 1.") self.t_initial = t_initial - self.t_mul = t_mul self.lr_min = lr_min - self.decay_rate = decay_rate + self.cycle_mul = cycle_mul + self.cycle_decay = cycle_decay self.cycle_limit = cycle_limit self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init self.warmup_prefix = warmup_prefix self.t_in_epochs = t_in_epochs + self.k_decay = k_decay if self.warmup_t: self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] super().update_groups(self.warmup_lr_init) @@ -72,22 +76,23 @@ class CosineLRScheduler(Scheduler): if self.warmup_prefix: t = t - self.warmup_t - if self.t_mul != 1: - i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) - t_i = self.t_mul ** i * self.t_initial - t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial + if self.cycle_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) + t_i = self.cycle_mul ** i * self.t_initial + t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial else: i = t // self.t_initial t_i = self.t_initial t_curr = t - (self.t_initial * i) - gamma = self.decay_rate ** i - lr_min = self.lr_min * gamma + gamma = self.cycle_decay ** i lr_max_values = [v * gamma for v in self.base_values] + k = self.k_decay - if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): + if i < self.cycle_limit: lrs = [ - lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values + self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k)) + for lr_max in lr_max_values ] else: lrs = [self.lr_min for _ in self.base_values] @@ -107,10 +112,8 @@ class CosineLRScheduler(Scheduler): return None def get_cycle_length(self, cycles=0): - if not cycles: - cycles = self.cycle_limit - cycles = max(1, cycles) - if self.t_mul == 1.0: + cycles = max(1, cycles or self.cycle_limit) + if self.cycle_mul == 1.0: return self.t_initial * cycles else: - return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) + return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) diff --git a/timm/scheduler/poly_lr.py b/timm/scheduler/poly_lr.py new file mode 100644 index 00000000..0c1e63b7 --- /dev/null +++ b/timm/scheduler/poly_lr.py @@ -0,0 +1,116 @@ +""" Polynomial Scheduler + +Polynomial LR schedule with warmup, noise. + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math +import logging + +import torch + +from .scheduler import Scheduler + + +_logger = logging.getLogger(__name__) + + +class PolyLRScheduler(Scheduler): + """ Polynomial LR Scheduler w/ warmup, noise, and k-decay + + k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909 + """ + + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + power: float = 0.5, + lr_min: float = 0., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, + warmup_t=0, + warmup_lr_init=0, + warmup_prefix=False, + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + k_decay=.5, + initialize=True) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + assert t_initial > 0 + assert lr_min >= 0 + if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1: + _logger.warning("Cosine annealing scheduler will have no effect on the learning " + "rate since t_initial = t_mul = eta_mul = 1.") + self.t_initial = t_initial + self.power = power + self.lr_min = lr_min + self.cycle_mul = cycle_mul + self.cycle_decay = cycle_decay + self.cycle_limit = cycle_limit + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.warmup_prefix = warmup_prefix + self.t_in_epochs = t_in_epochs + self.k_decay = k_decay + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + if self.warmup_prefix: + t = t - self.warmup_t + + if self.cycle_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) + t_i = self.cycle_mul ** i * self.t_initial + t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial + else: + i = t // self.t_initial + t_i = self.t_initial + t_curr = t - (self.t_initial * i) + + gamma = self.cycle_decay ** i + lr_max_values = [v * gamma for v in self.base_values] + k = self.k_decay + + if i < self.cycle_limit: + lrs = [ + self.lr_min + (lr_max - self.lr_min) * (1 - t_curr ** k / t_i ** k) ** self.power + for lr_max in lr_max_values + ] + else: + lrs = [self.lr_min for _ in self.base_values] + + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None + + def get_cycle_length(self, cycles=0): + cycles = max(1, cycles or self.cycle_limit) + if self.cycle_mul == 1.0: + return self.t_initial * cycles + else: + return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul))) diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index 51b65e00..72a979c2 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -1,11 +1,12 @@ """ Scheduler Factory -Hacked together by / Copyright 2020 Ross Wightman +Hacked together by / Copyright 2021 Ross Wightman """ from .cosine_lr import CosineLRScheduler -from .tanh_lr import TanhLRScheduler -from .step_lr import StepLRScheduler -from .plateau_lr import PlateauLRScheduler from .multistep_lr import MultiStepLRScheduler +from .plateau_lr import PlateauLRScheduler +from .poly_lr import PolyLRScheduler +from .step_lr import StepLRScheduler +from .tanh_lr import TanhLRScheduler def create_scheduler(args, optimizer): @@ -27,19 +28,22 @@ def create_scheduler(args, optimizer): noise_std=getattr(args, 'lr_noise_std', 1.), noise_seed=getattr(args, 'seed', 42), ) + cycle_args = dict( + cycle_mul=getattr(args, 'lr_cycle_mul', 1.), + cycle_decay=getattr(args, 'lr_cycle_decay', 0.1), + cycle_limit=getattr(args, 'lr_cycle_limit', 1), + ) lr_scheduler = None if args.sched == 'cosine': lr_scheduler = CosineLRScheduler( optimizer, t_initial=num_epochs, - t_mul=getattr(args, 'lr_cycle_mul', 1.), lr_min=args.min_lr, - decay_rate=args.decay_rate, warmup_lr_init=args.warmup_lr, warmup_t=args.warmup_epochs, - cycle_limit=getattr(args, 'lr_cycle_limit', 1), - t_in_epochs=True, + k_decay=getattr(args, 'lr_k_decay', 1.0), + **cycle_args, **noise_args, ) num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs @@ -47,12 +51,11 @@ def create_scheduler(args, optimizer): lr_scheduler = TanhLRScheduler( optimizer, t_initial=num_epochs, - t_mul=getattr(args, 'lr_cycle_mul', 1.), lr_min=args.min_lr, warmup_lr_init=args.warmup_lr, warmup_t=args.warmup_epochs, - cycle_limit=getattr(args, 'lr_cycle_limit', 1), t_in_epochs=True, + **cycle_args, **noise_args, ) num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs @@ -87,5 +90,18 @@ def create_scheduler(args, optimizer): cooldown_t=0, **noise_args, ) + elif args.sched == 'poly': + lr_scheduler = PolyLRScheduler( + optimizer, + power=args.decay_rate, # overloading 'decay_rate' as polynomial power + t_initial=num_epochs, + lr_min=args.min_lr, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + k_decay=getattr(args, 'lr_k_decay', 1.0), + **cycle_args, + **noise_args, + ) + num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs return lr_scheduler, num_epochs diff --git a/timm/scheduler/tanh_lr.py b/timm/scheduler/tanh_lr.py index 8cc338bb..f2d3c9cd 100644 --- a/timm/scheduler/tanh_lr.py +++ b/timm/scheduler/tanh_lr.py @@ -2,7 +2,7 @@ TanH schedule with warmup, cycle/restarts, noise. -Hacked together by / Copyright 2020 Ross Wightman +Hacked together by / Copyright 2021 Ross Wightman """ import logging import math @@ -24,15 +24,15 @@ class TanhLRScheduler(Scheduler): def __init__(self, optimizer: torch.optim.Optimizer, t_initial: int, - lb: float = -6., - ub: float = 4., - t_mul: float = 1., + lb: float = -7., + ub: float = 3., lr_min: float = 0., - decay_rate: float = 1., + cycle_mul: float = 1., + cycle_decay: float = 1., + cycle_limit: int = 1, warmup_t=0, warmup_lr_init=0, warmup_prefix=False, - cycle_limit=0, t_in_epochs=True, noise_range_t=None, noise_pct=0.67, @@ -53,9 +53,9 @@ class TanhLRScheduler(Scheduler): self.lb = lb self.ub = ub self.t_initial = t_initial - self.t_mul = t_mul self.lr_min = lr_min - self.decay_rate = decay_rate + self.cycle_mul = cycle_mul + self.cycle_decay = cycle_decay self.cycle_limit = cycle_limit self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init @@ -75,27 +75,26 @@ class TanhLRScheduler(Scheduler): if self.warmup_prefix: t = t - self.warmup_t - if self.t_mul != 1: - i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) - t_i = self.t_mul ** i * self.t_initial - t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial + if self.cycle_mul != 1: + i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) + t_i = self.cycle_mul ** i * self.t_initial + t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial else: i = t // self.t_initial t_i = self.t_initial t_curr = t - (self.t_initial * i) - if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): - gamma = self.decay_rate ** i - lr_min = self.lr_min * gamma + if i < self.cycle_limit: + gamma = self.cycle_decay ** i lr_max_values = [v * gamma for v in self.base_values] tr = t_curr / t_i lrs = [ - lr_min + 0.5 * (lr_max - lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr)) + self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr)) for lr_max in lr_max_values ] else: - lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values] + lrs = [self.lr_min for _ in self.base_values] return lrs def get_epoch_values(self, epoch: int): @@ -111,10 +110,8 @@ class TanhLRScheduler(Scheduler): return None def get_cycle_length(self, cycles=0): - if not cycles: - cycles = self.cycle_limit - cycles = max(1, cycles) - if self.t_mul == 1.0: + cycles = max(1, cycles or self.cycle_limit) + if self.cycle_mul == 1.0: return self.t_initial * cycles else: - return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) + return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))