|
|
|
from typing import Any, Dict, Optional

import torch
|
|
|
|
|
|
|
|
|
|
|
|
class Scheduler:
    """Parameter scheduler base class.

    A scheduler base class that can be used to schedule any optimizer
    parameter-group field (for example ``'lr'``).

    Unlike the builtin PyTorch schedulers, this is intended to be consistently
    called:

    * At the END of each epoch, before incrementing the epoch count, to
      calculate the next epoch's value.
    * At the END of each optimizer update, after incrementing the update
      count, to calculate the next update's value.

    The schedulers built on this should try to remain as stateless as possible
    (for simplicity).

    This family of schedulers is attempting to avoid the confusion of the
    meaning of 'last_epoch' and -1 values for special behaviour. All epoch and
    update counts must be tracked in the training code and explicitly passed
    in to the schedulers on the corresponding ``step`` or ``step_update`` call.

    Subclasses override ``get_epoch_values`` and/or ``get_update_values``;
    returning ``None`` from either leaves the param groups unchanged.

    Based on ideas from:

    * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
    * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 param_group_field: str,
                 initialize: bool = True) -> None:
        """Record each param group's base value and write it back to the group.

        Args:
            optimizer: Optimizer whose ``param_groups`` are scheduled.
            param_group_field: Key in each param group to schedule.
            initialize: If True, copy each group's current ``param_group_field``
                value into ``'initial_<param_group_field>'``. A value already
                stored under that key is kept. If False, every group must
                already contain ``'initial_<param_group_field>'``.

        Raises:
            KeyError: If a param group lacks ``param_group_field`` (when
                ``initialize`` is True) or ``'initial_<param_group_field>'``
                (when ``initialize`` is False).
        """
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"
        if initialize:
            for i, group in enumerate(self.optimizer.param_groups):
                if param_group_field not in group:
                    raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
                # setdefault: an initial value already present in the group is kept.
                group.setdefault(self._initial_param_group_field, group[param_group_field])
        else:
            for i, group in enumerate(self.optimizer.param_groups):
                if self._initial_param_group_field not in group:
                    raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
        self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
        # Most recent metric passed to step()/step_update(); this base class
        # only stores it and never reads it.
        self.metric = None
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        """Return the scheduler's instance attributes, excluding the optimizer.

        The returned dict is a new mapping, but its values are not copied.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore instance attributes from ``state_dict``.

        An ``'optimizer'`` entry, if present, is ignored. This keeps the
        scheduler bound to the optimizer it was constructed with and mirrors
        ``state_dict``, which never exports that key.
        """
        self.__dict__.update({key: value for key, value in state_dict.items() if key != 'optimizer'})

    def get_epoch_values(self, epoch: int):
        """Subclass hook: value(s) to use for ``epoch``, or None for no change."""
        return None

    def get_update_values(self, num_updates: int):
        """Subclass hook: value(s) to use at ``num_updates``, or None for no change."""
        return None

    def step(self, epoch: int, metric: Optional[float] = None) -> None:
        """Call at the end of epoch ``epoch``; applies the values for ``epoch + 1``.

        Args:
            epoch: The epoch that just finished.
            metric: Optional metric, stored on ``self.metric`` before the
                subclass hook runs.
        """
        self.metric = metric
        values = self.get_epoch_values(epoch + 1)  # +1 to calculate for next epoch
        if values is not None:
            self.update_groups(values)

    def step_update(self, num_updates: int, metric: Optional[float] = None) -> None:
        """Call after incrementing the update count; applies values for ``num_updates``.

        Args:
            num_updates: Update count, already incremented by the caller.
            metric: Optional metric, stored on ``self.metric`` before the
                subclass hook runs.
        """
        self.metric = metric
        values = self.get_update_values(num_updates)
        if values is not None:
            self.update_groups(values)

    def update_groups(self, values) -> None:
        """Write ``values`` into ``param_group_field`` of every param group.

        A list or tuple is applied element-wise in group order. Pairing uses
        ``zip``, so surplus entries on either side are ignored. Any other value
        is broadcast to all groups.
        """
        if not isinstance(values, (list, tuple)):
            values = [values] * len(self.optimizer.param_groups)
        for param_group, value in zip(self.optimizer.param_groups, values):
            param_group[self.param_group_field] = value
|