Scheduler update: add v2 factory method, support scheduling on updates instead of just epochs. Add LR to summary CSV. Add lr_base scaling calculations to train script. Fix #1168

pull/1479/head
Ross Wightman 2 years ago
parent 4f18d6dc5f
commit b1b024dfed

@@ -193,7 +193,8 @@ def create_optimizer_v2(
         filter_bias_and_bn: bool = True,
         layer_decay: Optional[float] = None,
         param_group_fn: Optional[Callable] = None,
-        **kwargs):
+        **kwargs,
+):
     """ Create an optimizer.

     TODO currently the model is passed in and all parameters are selected for optimization.

@@ -5,4 +5,4 @@ from .poly_lr import PolyLRScheduler
 from .step_lr import StepLRScheduler
 from .tanh_lr import TanhLRScheduler
-from .scheduler_factory import create_scheduler
+from .scheduler_factory import create_scheduler, create_scheduler_v2, scheduler_kwargs

@@ -26,7 +26,8 @@ class CosineLRScheduler(Scheduler):
     k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
     """

-    def __init__(self,
+    def __init__(
+            self,
             optimizer: torch.optim.Optimizer,
             t_initial: int,
             lr_min: float = 0.,
@@ -42,16 +43,24 @@ class CosineLRScheduler(Scheduler):
             noise_std=1.0,
             noise_seed=42,
             k_decay=1.0,
-            initialize=True) -> None:
+            initialize=True,
+    ) -> None:
         super().__init__(
-            optimizer, param_group_field="lr",
-            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
-            initialize=initialize)
+            optimizer,
+            param_group_field="lr",
+            t_in_epochs=t_in_epochs,
+            noise_range_t=noise_range_t,
+            noise_pct=noise_pct,
+            noise_std=noise_std,
+            noise_seed=noise_seed,
+            initialize=initialize,
+        )

         assert t_initial > 0
         assert lr_min >= 0
         if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
-            _logger.warning("Cosine annealing scheduler will have no effect on the learning "
-                            "rate since t_initial = t_mul = eta_mul = 1.")
+            _logger.warning(
+                "Cosine annealing scheduler will have no effect on the learning "
+                "rate since t_initial = t_mul = eta_mul = 1.")
         self.t_initial = t_initial
         self.lr_min = lr_min
@@ -61,7 +70,6 @@ class CosineLRScheduler(Scheduler):
         self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
         self.warmup_prefix = warmup_prefix
-        self.t_in_epochs = t_in_epochs
         self.k_decay = k_decay
         if self.warmup_t:
             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
@@ -99,18 +107,6 @@ class CosineLRScheduler(Scheduler):
         return lrs

-    def get_epoch_values(self, epoch: int):
-        if self.t_in_epochs:
-            return self._get_lr(epoch)
-        else:
-            return None
-
-    def get_update_values(self, num_updates: int):
-        if not self.t_in_epochs:
-            return self._get_lr(num_updates)
-        else:
-            return None
-
     def get_cycle_length(self, cycles=0):
         cycles = max(1, cycles or self.cycle_limit)
         if self.cycle_mul == 1.0:

@@ -11,12 +11,14 @@ class MultiStepLRScheduler(Scheduler):
     """
     """

-    def __init__(self,
+    def __init__(
+            self,
             optimizer: torch.optim.Optimizer,
             decay_t: List[int],
             decay_rate: float = 1.,
             warmup_t=0,
             warmup_lr_init=0,
+            warmup_prefix=True,
             t_in_epochs=True,
             noise_range_t=None,
             noise_pct=0.67,
@@ -25,15 +27,21 @@ class MultiStepLRScheduler(Scheduler):
             initialize=True,
     ) -> None:
         super().__init__(
-            optimizer, param_group_field="lr",
-            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
-            initialize=initialize)
+            optimizer,
+            param_group_field="lr",
+            t_in_epochs=t_in_epochs,
+            noise_range_t=noise_range_t,
+            noise_pct=noise_pct,
+            noise_std=noise_std,
+            noise_seed=noise_seed,
+            initialize=initialize,
+        )

         self.decay_t = decay_t
         self.decay_rate = decay_rate
         self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
-        self.t_in_epochs = t_in_epochs
+        self.warmup_prefix = warmup_prefix
         if self.warmup_t:
             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
             super().update_groups(self.warmup_lr_init)
@@ -49,17 +57,7 @@ class MultiStepLRScheduler(Scheduler):
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
+            if self.warmup_prefix:
+                t = t - self.warmup_t
             lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values]
         return lrs

-    def get_epoch_values(self, epoch: int):
-        if self.t_in_epochs:
-            return self._get_lr(epoch)
-        else:
-            return None
-
-    def get_update_values(self, num_updates: int):
-        if not self.t_in_epochs:
-            return self._get_lr(num_updates)
-        else:
-            return None

@@ -12,7 +12,8 @@ from .scheduler import Scheduler
 class PlateauLRScheduler(Scheduler):
     """Decay the LR by a factor every time the validation loss plateaus."""

-    def __init__(self,
+    def __init__(
+            self,
             optimizer,
             decay_rate=0.1,
             patience_t=10,
@@ -89,6 +90,9 @@ class PlateauLRScheduler(Scheduler):
             if self._is_apply_noise(epoch):
                 self._apply_noise(epoch)

+    def step_update(self, num_updates: int, metric: float = None):
+        return None
+
     def _apply_noise(self, epoch):
         noise = self._calculate_noise(epoch)
@@ -101,3 +105,6 @@ class PlateauLRScheduler(Scheduler):
                 new_lr = old_lr + old_lr * noise
                 param_group['lr'] = new_lr
         self.restore_lr = restore_lr
+
+    def _get_lr(self, t: int) -> float:
+        assert False, 'should not be called as step is overridden'
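Plateau scheduling stays metric-driven and epoch-only after this change: step_update() becomes a no-op and _get_lr() should never be reached because step() is overridden. A usage sketch under that assumption (model, optimizer, and metric values below are illustrative, not from the commit):

# Sketch: plateau scheduling is stepped once per epoch with a validation metric.
import torch
from timm.scheduler import PlateauLRScheduler

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = PlateauLRScheduler(opt, decay_rate=0.1, patience_t=2, mode='max')

for epoch, top1 in enumerate([10.0, 12.0, 12.0, 12.0, 12.0]):
    sched.step(epoch, metric=top1)   # decays lr once patience_t epochs pass w/o improvement

sched.step_update(num_updates=100)   # intentionally a no-op after this change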

@@ -21,7 +21,8 @@ class PolyLRScheduler(Scheduler):
     k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
     """

-    def __init__(self,
+    def __init__(
+            self,
             optimizer: torch.optim.Optimizer,
             t_initial: int,
             power: float = 0.5,
@@ -38,11 +39,18 @@ class PolyLRScheduler(Scheduler):
             noise_std=1.0,
             noise_seed=42,
             k_decay=1.0,
-            initialize=True) -> None:
+            initialize=True,
+    ) -> None:
         super().__init__(
-            optimizer, param_group_field="lr",
-            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
-            initialize=initialize)
+            optimizer,
+            param_group_field="lr",
+            t_in_epochs=t_in_epochs,
+            noise_range_t=noise_range_t,
+            noise_pct=noise_pct,
+            noise_std=noise_std,
+            noise_seed=noise_seed,
+            initialize=initialize
+        )

         assert t_initial > 0
         assert lr_min >= 0
@@ -58,7 +66,6 @@ class PolyLRScheduler(Scheduler):
         self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
         self.warmup_prefix = warmup_prefix
-        self.t_in_epochs = t_in_epochs
         self.k_decay = k_decay
         if self.warmup_t:
             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
@@ -96,18 +103,6 @@ class PolyLRScheduler(Scheduler):
         return lrs

-    def get_epoch_values(self, epoch: int):
-        if self.t_in_epochs:
-            return self._get_lr(epoch)
-        else:
-            return None
-
-    def get_update_values(self, num_updates: int):
-        if not self.t_in_epochs:
-            return self._get_lr(num_updates)
-        else:
-            return None
-
     def get_cycle_length(self, cycles=0):
         cycles = max(1, cycles or self.cycle_limit)
         if self.cycle_mul == 1.0:

@@ -1,9 +1,11 @@
-from typing import Dict, Any
+import abc
+from abc import ABC
+from typing import Any, Dict, Optional

 import torch


-class Scheduler:
+class Scheduler(ABC):
     """ Parameter Scheduler Base Class
     A scheduler base class that can be used to schedule any optimizer parameter groups.
@@ -22,15 +24,18 @@ class Scheduler:
     * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
     """

-    def __init__(self,
+    def __init__(
+            self,
             optimizer: torch.optim.Optimizer,
             param_group_field: str,
+            t_in_epochs: bool = True,
             noise_range_t=None,
             noise_type='normal',
             noise_pct=0.67,
             noise_std=1.0,
             noise_seed=None,
-            initialize: bool = True) -> None:
+            initialize: bool = True,
+    ) -> None:
         self.optimizer = optimizer
         self.param_group_field = param_group_field
         self._initial_param_group_field = f"initial_{param_group_field}"
@@ -45,6 +50,7 @@ class Scheduler:
                 raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
         self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
         self.metric = None  # any point to having this for all?
+        self.t_in_epochs = t_in_epochs
         self.noise_range_t = noise_range_t
         self.noise_pct = noise_pct
         self.noise_type = noise_type
@@ -58,22 +64,26 @@ class Scheduler:
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self.__dict__.update(state_dict)

-    def get_epoch_values(self, epoch: int):
-        return None
+    @abc.abstractmethod
+    def _get_lr(self, t: int) -> float:
+        pass

-    def get_update_values(self, num_updates: int):
-        return None
+    def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]:
+        proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs)
+        if not proceed:
+            return None
+        return self._get_lr(t)

     def step(self, epoch: int, metric: float = None) -> None:
         self.metric = metric
-        values = self.get_epoch_values(epoch)
+        values = self._get_values(epoch, on_epoch=True)
         if values is not None:
             values = self._add_noise(values, epoch)
             self.update_groups(values)

     def step_update(self, num_updates: int, metric: float = None):
         self.metric = metric
-        values = self.get_update_values(num_updates)
+        values = self._get_values(num_updates, on_epoch=False)
         if values is not None:
             values = self._add_noise(values, num_updates)
             self.update_groups(values)
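After this refactor the base class owns the epoch-vs-update decision: subclasses implement only _get_lr(t), and step() / step_update() route through _get_values() according to t_in_epochs. A minimal sketch of a custom subclass under that contract (the LinearDecay class and its numbers are illustrative, not part of timm; the module path timm.scheduler.scheduler is assumed):

# Illustrative sketch: a toy scheduler built on the refactored timm Scheduler base class.
import torch
from timm.scheduler.scheduler import Scheduler


class LinearDecay(Scheduler):
    """Linearly decay each param group's lr to zero over total_t steps (toy example)."""

    def __init__(self, optimizer, total_t, t_in_epochs=True):
        super().__init__(optimizer, param_group_field="lr", t_in_epochs=t_in_epochs)
        self.total_t = total_t

    def _get_lr(self, t):
        frac = max(0., 1. - t / self.total_t)
        return [v * frac for v in self.base_values]


model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=1.0)
sched = LinearDecay(opt, total_t=10, t_in_epochs=False)

sched.step(epoch=5)               # ignored: this scheduler is stepped per update
sched.step_update(num_updates=5)  # applied: lr halves to 0.5
print(opt.param_groups[0]['lr'])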

@@ -1,6 +1,10 @@
 """ Scheduler Factory
 Hacked together by / Copyright 2021 Ross Wightman
 """
+from typing import List, Union
+
+from torch.optim import Optimizer
+
 from .cosine_lr import CosineLRScheduler
 from .multistep_lr import MultiStepLRScheduler
 from .plateau_lr import PlateauLRScheduler
@@ -9,99 +13,191 @@ from .step_lr import StepLRScheduler
 from .tanh_lr import TanhLRScheduler


-def create_scheduler(args, optimizer):
-    num_epochs = args.epochs
+def scheduler_kwargs(cfg):
+    """ cfg/argparse to kwargs helper
+    Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
+    """
+    eval_metric = getattr(cfg, 'eval_metric', 'top1')
+    plateau_mode = 'min' if 'loss' in eval_metric else 'max'
+    kwargs = dict(
+        sched=cfg.sched,
+        num_epochs=getattr(cfg, 'epochs', 100),
+        decay_epochs=getattr(cfg, 'decay_epochs', 30),
+        decay_milestones=getattr(cfg, 'decay_milestones', [30, 60]),
+        warmup_epochs=getattr(cfg, 'warmup_epochs', 5),
+        cooldown_epochs=getattr(cfg, 'cooldown_epochs', 0),
+        patience_epochs=getattr(cfg, 'patience_epochs', 10),
+        decay_rate=getattr(cfg, 'decay_rate', 0.1),
+        min_lr=getattr(cfg, 'min_lr', 0.),
+        warmup_lr=getattr(cfg, 'warmup_lr', 1e-5),
+        warmup_prefix=getattr(cfg, 'warmup_prefix', False),
+        noise=getattr(cfg, 'lr_noise', None),
+        noise_pct=getattr(cfg, 'lr_noise_pct', 0.67),
+        noise_std=getattr(cfg, 'lr_noise_std', 1.),
+        noise_seed=getattr(cfg, 'seed', 42),
+        cycle_mul=getattr(cfg, 'lr_cycle_mul', 1.),
+        cycle_decay=getattr(cfg, 'lr_cycle_decay', 0.1),
+        cycle_limit=getattr(cfg, 'lr_cycle_limit', 1),
+        k_decay=getattr(cfg, 'lr_k_decay', 1.0),
+        plateau_mode=plateau_mode,
+        step_on_epochs=not getattr(cfg, 'sched_on_updates', False),
+    )
+    return kwargs
+
+
+def create_scheduler(
+        args,
+        optimizer: Optimizer,
+        updates_per_epoch: int = 0,
+):
+    return create_scheduler_v2(
+        optimizer=optimizer,
+        **scheduler_kwargs(args),
+        updates_per_epoch=updates_per_epoch,
+    )
+

-    if getattr(args, 'lr_noise', None) is not None:
-        lr_noise = getattr(args, 'lr_noise')
-        if isinstance(lr_noise, (list, tuple)):
-            noise_range = [n * num_epochs for n in lr_noise]
+def create_scheduler_v2(
+        optimizer: Optimizer,
+        sched: str = 'cosine',
+        num_epochs: int = 300,
+        decay_epochs: int = 90,
+        decay_milestones: List[int] = (90, 180, 270),
+        cooldown_epochs: int = 0,
+        patience_epochs: int = 10,
+        decay_rate: float = 0.1,
+        min_lr: float = 0,
+        warmup_lr: float = 1e-5,
+        warmup_epochs: int = 0,
+        warmup_prefix: bool = False,
+        noise: Union[float, List[float]] = None,
+        noise_pct: float = 0.67,
+        noise_std: float = 1.,
+        noise_seed: int = 42,
+        cycle_mul: float = 1.,
+        cycle_decay: float = 0.1,
+        cycle_limit: int = 1,
+        k_decay: float = 1.0,
+        plateau_mode: str = 'max',
+        step_on_epochs: bool = True,
+        updates_per_epoch: int = 0,
+):
+    t_initial = num_epochs
+    warmup_t = warmup_epochs
+    decay_t = decay_epochs
+    cooldown_t = cooldown_epochs
+
+    if not step_on_epochs:
+        assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches'
+        t_initial = t_initial * updates_per_epoch
+        warmup_t = warmup_t * updates_per_epoch
+        decay_t = decay_t * updates_per_epoch
+        decay_milestones = [d * updates_per_epoch for d in decay_milestones]
+        cooldown_t = cooldown_t * updates_per_epoch
+
+    # warmup args
+    warmup_args = dict(
+        warmup_lr_init=warmup_lr,
+        warmup_t=warmup_t,
+        warmup_prefix=warmup_prefix,
+    )
+
+    # setup noise args for supporting schedulers
+    if noise is not None:
+        if isinstance(noise, (list, tuple)):
+            noise_range = [n * t_initial for n in noise]
             if len(noise_range) == 1:
                 noise_range = noise_range[0]
         else:
-            noise_range = lr_noise * num_epochs
+            noise_range = noise * t_initial
     else:
         noise_range = None
     noise_args = dict(
         noise_range_t=noise_range,
-        noise_pct=getattr(args, 'lr_noise_pct', 0.67),
-        noise_std=getattr(args, 'lr_noise_std', 1.),
-        noise_seed=getattr(args, 'seed', 42),
+        noise_pct=noise_pct,
+        noise_std=noise_std,
+        noise_seed=noise_seed,
     )
+
+    # setup cycle args for supporting schedulers
     cycle_args = dict(
-        cycle_mul=getattr(args, 'lr_cycle_mul', 1.),
-        cycle_decay=getattr(args, 'lr_cycle_decay', 0.1),
-        cycle_limit=getattr(args, 'lr_cycle_limit', 1),
+        cycle_mul=cycle_mul,
+        cycle_decay=cycle_decay,
+        cycle_limit=cycle_limit,
     )

     lr_scheduler = None
-    if args.sched == 'cosine':
+    if sched == 'cosine':
         lr_scheduler = CosineLRScheduler(
             optimizer,
-            t_initial=num_epochs,
-            lr_min=args.min_lr,
-            warmup_lr_init=args.warmup_lr,
-            warmup_t=args.warmup_epochs,
-            k_decay=getattr(args, 'lr_k_decay', 1.0),
+            t_initial=t_initial,
+            lr_min=min_lr,
+            t_in_epochs=step_on_epochs,
             **cycle_args,
+            **warmup_args,
             **noise_args,
+            k_decay=k_decay,
         )
-        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
-    elif args.sched == 'tanh':
+    elif sched == 'tanh':
         lr_scheduler = TanhLRScheduler(
             optimizer,
-            t_initial=num_epochs,
-            lr_min=args.min_lr,
-            warmup_lr_init=args.warmup_lr,
-            warmup_t=args.warmup_epochs,
-            t_in_epochs=True,
+            t_initial=t_initial,
+            lr_min=min_lr,
+            t_in_epochs=step_on_epochs,
             **cycle_args,
+            **warmup_args,
             **noise_args,
         )
-        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
-    elif args.sched == 'step':
+    elif sched == 'step':
         lr_scheduler = StepLRScheduler(
             optimizer,
-            decay_t=args.decay_epochs,
-            decay_rate=args.decay_rate,
-            warmup_lr_init=args.warmup_lr,
-            warmup_t=args.warmup_epochs,
+            decay_t=decay_t,
+            decay_rate=decay_rate,
+            t_in_epochs=step_on_epochs,
+            **warmup_args,
             **noise_args,
         )
-    elif args.sched == 'multistep':
+    elif sched == 'multistep':
         lr_scheduler = MultiStepLRScheduler(
             optimizer,
-            decay_t=args.decay_milestones,
-            decay_rate=args.decay_rate,
-            warmup_lr_init=args.warmup_lr,
-            warmup_t=args.warmup_epochs,
+            decay_t=decay_milestones,
+            decay_rate=decay_rate,
+            t_in_epochs=step_on_epochs,
+            **warmup_args,
             **noise_args,
         )
-    elif args.sched == 'plateau':
-        mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
+    elif sched == 'plateau':
+        assert step_on_epochs, 'Plateau LR only supports step per epoch.'
+        warmup_args.pop('warmup_prefix', False)
         lr_scheduler = PlateauLRScheduler(
             optimizer,
-            decay_rate=args.decay_rate,
-            patience_t=args.patience_epochs,
-            lr_min=args.min_lr,
-            mode=mode,
-            warmup_lr_init=args.warmup_lr,
-            warmup_t=args.warmup_epochs,
             cooldown_t=0,
+            decay_rate=decay_rate,
+            patience_t=patience_epochs,
+            **warmup_args,
+            lr_min=min_lr,
+            mode=plateau_mode,
             **noise_args,
         )
-    elif args.sched == 'poly':
+    elif sched == 'poly':
         lr_scheduler = PolyLRScheduler(
             optimizer,
-            power=args.decay_rate,  # overloading 'decay_rate' as polynomial power
-            t_initial=num_epochs,
-            lr_min=args.min_lr,
-            warmup_lr_init=args.warmup_lr,
-            warmup_t=args.warmup_epochs,
-            k_decay=getattr(args, 'lr_k_decay', 1.0),
+            power=decay_rate,  # overloading 'decay_rate' as polynomial power
+            t_initial=t_initial,
+            lr_min=min_lr,
+            t_in_epochs=step_on_epochs,
+            k_decay=k_decay,
             **cycle_args,
+            **warmup_args,
             **noise_args,
         )
-        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
+
+    if hasattr(lr_scheduler, 'get_cycle_length'):
+        # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
+        t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
+        if step_on_epochs:
+            num_epochs = t_with_cycles_and_cooldown
+        else:
+            num_epochs = t_with_cycles_and_cooldown // updates_per_epoch

     return lr_scheduler, num_epochs
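A minimal usage sketch of the new factory (values below are illustrative): with step_on_epochs=False, every epoch-denominated argument is multiplied by updates_per_epoch internally, and the returned num_epochs is converted back to epoch units so the outer training loop length is unchanged.

# Sketch: create a cosine schedule that steps per optimizer update.
import torch
from timm.scheduler import create_scheduler_v2

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.4)

updates_per_epoch = 1250            # e.g. len(train_loader); illustrative
lr_scheduler, num_epochs = create_scheduler_v2(
    optimizer,
    sched='cosine',
    num_epochs=100,
    warmup_epochs=5,
    warmup_lr=1e-5,
    min_lr=1e-6,
    cooldown_epochs=10,
    step_on_epochs=False,           # t_initial etc. become update counts internally
    updates_per_epoch=updates_per_epoch,
)
# num_epochs is the cycle length plus cooldown, converted back to epochs (110 here)
print(num_epochs, lr_scheduler.t_in_epochs)  # -> 110 False

scheduler_kwargs(args) produces the same keyword set from an argparse-style namespace, which is how both the train script and the legacy create_scheduler() wrapper call into the v2 factory.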

@@ -14,12 +14,14 @@ class StepLRScheduler(Scheduler):
     """
     """

-    def __init__(self,
+    def __init__(
+            self,
             optimizer: torch.optim.Optimizer,
             decay_t: float,
             decay_rate: float = 1.,
             warmup_t=0,
             warmup_lr_init=0,
+            warmup_prefix=True,
             t_in_epochs=True,
             noise_range_t=None,
             noise_pct=0.67,
@@ -28,15 +30,21 @@ class StepLRScheduler(Scheduler):
             initialize=True,
     ) -> None:
         super().__init__(
-            optimizer, param_group_field="lr",
-            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
-            initialize=initialize)
+            optimizer,
+            param_group_field="lr",
+            t_in_epochs=t_in_epochs,
+            noise_range_t=noise_range_t,
+            noise_pct=noise_pct,
+            noise_std=noise_std,
+            noise_seed=noise_seed,
+            initialize=initialize,
+        )

         self.decay_t = decay_t
         self.decay_rate = decay_rate
         self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
-        self.t_in_epochs = t_in_epochs
+        self.warmup_prefix = warmup_prefix
         if self.warmup_t:
             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
             super().update_groups(self.warmup_lr_init)
@@ -47,17 +55,7 @@ class StepLRScheduler(Scheduler):
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
+            if self.warmup_prefix:
+                t = t - self.warmup_t
             lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
         return lrs

-    def get_epoch_values(self, epoch: int):
-        if self.t_in_epochs:
-            return self._get_lr(epoch)
-        else:
-            return None
-
-    def get_update_values(self, num_updates: int):
-        if not self.t_in_epochs:
-            return self._get_lr(num_updates)
-        else:
-            return None
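Both StepLRScheduler and MultiStepLRScheduler gain a warmup_prefix flag (default True) that shifts t by warmup_t before the decay step count is computed, so the decay clock starts after warmup ends. A small sketch of the arithmetic (numbers illustrative):

# Sketch: StepLRScheduler with warmup_prefix, stepped per epoch.
import torch
from timm.scheduler import StepLRScheduler

model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

sched = StepLRScheduler(
    opt,
    decay_t=30,            # decay every 30 epochs
    decay_rate=0.1,
    warmup_t=5,
    warmup_lr_init=1e-4,
    warmup_prefix=True,    # decay clock starts after the 5 warmup epochs
)

for epoch in (0, 5, 34, 35):
    sched.step(epoch)
    print(epoch, opt.param_groups[0]['lr'])
# with warmup_prefix=True the first 0.1 -> 0.01 drop lands at epoch 35 (30 + 5), not epoch 30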

@@ -21,7 +21,8 @@ class TanhLRScheduler(Scheduler):
     This is described in the paper https://arxiv.org/abs/1806.01593
     """

-    def __init__(self,
+    def __init__(
+            self,
             optimizer: torch.optim.Optimizer,
             t_initial: int,
             lb: float = -7.,
@@ -38,11 +39,18 @@ class TanhLRScheduler(Scheduler):
             noise_pct=0.67,
             noise_std=1.0,
             noise_seed=42,
-            initialize=True) -> None:
+            initialize=True,
+    ) -> None:
         super().__init__(
-            optimizer, param_group_field="lr",
-            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
-            initialize=initialize)
+            optimizer,
+            param_group_field="lr",
+            t_in_epochs=t_in_epochs,
+            noise_range_t=noise_range_t,
+            noise_pct=noise_pct,
+            noise_std=noise_std,
+            noise_seed=noise_seed,
+            initialize=initialize,
+        )

         assert t_initial > 0
         assert lr_min >= 0
@@ -60,7 +68,6 @@ class TanhLRScheduler(Scheduler):
         self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
         self.warmup_prefix = warmup_prefix
-        self.t_in_epochs = t_in_epochs
         if self.warmup_t:
             t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
@@ -97,18 +104,6 @@ class TanhLRScheduler(Scheduler):
             lrs = [self.lr_min for _ in self.base_values]
         return lrs

-    def get_epoch_values(self, epoch: int):
-        if self.t_in_epochs:
-            return self._get_lr(epoch)
-        else:
-            return None
-
-    def get_update_values(self, num_updates: int):
-        if not self.t_in_epochs:
-            return self._get_lr(num_updates)
-        else:
-            return None
-
     def get_cycle_length(self, cycles=0):
         cycles = max(1, cycles or self.cycle_limit)
         if self.cycle_mul == 1.0:

@@ -10,6 +10,7 @@ try:
 except ImportError:
     pass

+
 def get_outdir(path, *paths, inc=False):
     outdir = os.path.join(path, *paths)
     if not os.path.exists(outdir):
@@ -26,10 +27,20 @@ def get_outdir(path, *paths, inc=False):
     return outdir


-def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False):
+def update_summary(
+        epoch,
+        train_metrics,
+        eval_metrics,
+        filename,
+        lr=None,
+        write_header=False,
+        log_wandb=False,
+):
     rowd = OrderedDict(epoch=epoch)
     rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
     rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
+    if lr is not None:
+        rowd['lr'] = lr
     if log_wandb:
         wandb.log(rowd)
     with open(filename, mode='a') as cf:
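A small sketch of the extended helper: the learning rate is passed explicitly and lands in its own 'lr' column of the summary CSV (the metric values and filename below are illustrative):

# Sketch: writing one summary row that now includes the learning rate.
from timm import utils

train_metrics = {'loss': 2.31}
eval_metrics = {'loss': 2.10, 'top1': 45.2, 'top5': 71.9}
lrs = [0.05, 0.05]                      # e.g. one lr per optimizer param group

utils.update_summary(
    0,                                  # epoch
    train_metrics,
    eval_metrics,
    filename='summary.csv',
    lr=sum(lrs) / len(lrs),
    write_header=True,                  # first row writes the CSV header
)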

@@ -36,7 +36,7 @@ from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntrop
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
     convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm
 from timm.optim import create_optimizer_v2, optimizer_kwargs
-from timm.scheduler import create_scheduler
+from timm.scheduler import create_scheduler_v2, scheduler_kwargs
 from timm.utils import ApexScaler, NativeScaler

 try:
@@ -163,10 +163,18 @@ group.add_argument('--layer-decay', type=float, default=None,
 # Learning rate schedule parameters
 group = parser.add_argument_group('Learning rate schedule parameters')
-group.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
                    help='LR scheduler (default: "step"')
-group.add_argument('--lr', type=float, default=0.05, metavar='LR',
-                   help='learning rate (default: 0.05)')
+group.add_argument('--sched-on-updates', action='store_true', default=False,
+                   help='Apply LR scheduler step on update instead of epoch end.')
+group.add_argument('--lr', type=float, default=None, metavar='LR',
+                   help='learning rate, overrides lr-base if set (default: None)')
+group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
+                   help='base learning rate: lr = lr_base * global_batch_size / base_size')
+group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
+                   help='base learning rate batch size (divisor, default: 256).')
+group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
+                   help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
 group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
 group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
@@ -181,23 +189,25 @@ group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit, cycles enabled if > 1')
 group.add_argument('--lr-k-decay', type=float, default=1.0,
                    help='learning rate k-decay for cosine/poly (default: 1.0)')
-group.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
-                   help='warmup learning rate (default: 0.0001)')
-group.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
-                   help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
+                   help='warmup learning rate (default: 1e-5)')
+group.add_argument('--min-lr', type=float, default=0, metavar='LR',
+                   help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
 group.add_argument('--epochs', type=int, default=300, metavar='N',
                    help='number of epochs to train (default: 300)')
 group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
                    help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
 group.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
-group.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES",
+group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
                    help='list of decay epoch indices for multistep lr. must be increasing')
-group.add_argument('--decay-epochs', type=float, default=100, metavar='N',
+group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
                    help='epoch interval to decay LR')
-group.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
+group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
-group.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
+group.add_argument('--warmup-prefix', action='store_true', default=False,
+                   help='Exclude warmup period from decay schedule.'),
+group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
 group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10')
@@ -469,6 +479,20 @@ def main():
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)

+    if args.lr is None:
+        global_batch_size = args.batch_size * args.world_size
+        batch_ratio = global_batch_size / args.lr_base_size
+        if not args.lr_base_scale:
+            on = args.opt.lower()
+            args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
+        if args.lr_base_scale == 'sqrt':
+            batch_ratio = batch_ratio ** 0.5
+        args.lr = args.lr_base * batch_ratio
+        if utils.is_primary(args):
+            _logger.info(
+                f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
+                f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
+
     optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

     # setup automatic mixed-precision (AMP) loss scaling and op casting
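The scaling rule in the block above as standalone arithmetic (batch sizes illustrative): with the default lr_base=0.1 and lr_base_size=256, a global batch of 1024 gives lr = 0.4 under linear scaling and lr = 0.2 under sqrt scaling; sqrt is auto-selected for adaptive optimizers whose name contains 'ada' or 'lamb'. The scale_lr helper below is only a sketch, not part of the script:

# Sketch of the base-lr scaling performed in train.py when --lr is not given.
def scale_lr(lr_base=0.1, lr_base_size=256, global_batch_size=1024, scale='linear'):
    batch_ratio = global_batch_size / lr_base_size
    if scale == 'sqrt':
        batch_ratio = batch_ratio ** 0.5
    return lr_base * batch_ratio

print(scale_lr(scale='linear'))  # -> 0.4
print(scale_lr(scale='sqrt'))    # -> 0.2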
@@ -523,20 +547,6 @@ def main():
             model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
         # NOTE: EMA model does not need to be wrapped by DDP

-    # setup learning rate schedule and starting epoch
-    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
-    start_epoch = 0
-    if args.start_epoch is not None:
-        # a specified start_epoch will always override the resume epoch
-        start_epoch = args.start_epoch
-    elif resume_epoch is not None:
-        start_epoch = resume_epoch
-    if lr_scheduler is not None and start_epoch > 0:
-        lr_scheduler.step(start_epoch)
-
-    if utils.is_primary(args):
-        _logger.info('Scheduled epochs: {}'.format(num_epochs))
-
     # create the train and eval datasets
     dataset_train = create_dataset(
         args.dataset,
@@ -691,6 +701,29 @@ def main():
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)

+    # setup learning rate schedule and starting epoch
+    updates_per_epoch = len(loader_train)
+    lr_scheduler, num_epochs = create_scheduler_v2(
+        optimizer,
+        **scheduler_kwargs(args),
+        updates_per_epoch=updates_per_epoch,
+    )
+    start_epoch = 0
+    if args.start_epoch is not None:
+        # a specified start_epoch will always override the resume epoch
+        start_epoch = args.start_epoch
+    elif resume_epoch is not None:
+        start_epoch = resume_epoch
+    if lr_scheduler is not None and start_epoch > 0:
+        if args.sched_on_updates:
+            lr_scheduler.step_update(start_epoch * updates_per_epoch)
+        else:
+            lr_scheduler.step(start_epoch)
+
+    if utils.is_primary(args):
+        _logger.info(
+            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')
+
     try:
         for epoch in range(start_epoch, num_epochs):
             if hasattr(dataset_train, 'set_epoch'):
@@ -741,16 +774,14 @@ def main():
                 )
                 eval_metrics = ema_eval_metrics

-            if lr_scheduler is not None:
-                # step LR for next epoch
-                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
-
             if output_dir is not None:
+                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
                 utils.update_summary(
                     epoch,
                     train_metrics,
                     eval_metrics,
-                    os.path.join(output_dir, 'summary.csv'),
+                    filename=os.path.join(output_dir, 'summary.csv'),
+                    lr=sum(lrs) / len(lrs),
                     write_header=best_metric is None,
                     log_wandb=args.log_wandb and has_wandb,
                 )
@@ -760,8 +791,13 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

+            if lr_scheduler is not None:
+                # step LR for next epoch
+                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
+
     except KeyboardInterrupt:
         pass
+
     if best_metric is not None:
         _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
@@ -796,8 +832,9 @@ def train_one_epoch(
     model.train()

     end = time.time()
-    last_idx = len(loader) - 1
-    num_updates = epoch * len(loader)
+    num_batches_per_epoch = len(loader)
+    last_idx = num_batches_per_epoch - 1
+    num_updates = epoch * num_batches_per_epoch
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
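With per-update schedules the train loop advances num_updates and calls step_update() every iteration; the scheduler decides internally whether to act based on t_in_epochs. A self-contained sketch of that pattern (toy model, random data, and loop trimmed relative to the real train_one_epoch):

# Self-contained sketch of the per-update stepping pattern used in train_one_epoch.
import torch
from timm.scheduler import create_scheduler_v2

model = torch.nn.Linear(16, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.4)
loader = [(torch.randn(8, 16), torch.randint(0, 10, (8,))) for _ in range(100)]
loss_fn = torch.nn.CrossEntropyLoss()

lr_scheduler, num_epochs = create_scheduler_v2(
    optimizer, sched='cosine', num_epochs=2, warmup_epochs=0,
    step_on_epochs=False, updates_per_epoch=len(loader),
)

for epoch in range(num_epochs):
    num_updates = epoch * len(loader)
    for input, target in loader:
        loss = loss_fn(model(input), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        num_updates += 1
        lr_scheduler.step_update(num_updates=num_updates)  # advances the cosine curve per update

    lr_scheduler.step(epoch + 1)  # no-op for update-stepped schedules, kept for parity with train.py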
