""" Checkpoint Manager
|
|
|
|
Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
|
|
|
|
Hacked together by / Copyright 2021 Ross Wightman
|
|
"""
|
|
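
# Example usage (illustrative sketch; assumes a populated TrainState `train_state` and an
# `eval_metrics` dict such as {'loss': 0.42} produced by the train / eval loop):
#
#   manager = CheckpointManager(checkpoint_dir='./output', metric_name='loss', metric_decreasing=True)
#   for epoch in range(num_epochs):
#       ...  # train and evaluate for one epoch, updating train_state and eval_metrics
#       manager.save_checkpoint(train_state, metrics=eval_metrics)
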
import glob
import logging
import operator
import os
import shutil
from typing import Optional, Dict, Callable, List
from dataclasses import dataclass, replace

from .checkpoint import save_train_state
from .train_state import TrainState

_logger = logging.getLogger(__name__)


@dataclass
class CheckpointInfo:
    """ Metadata for one saved checkpoint: where it was written, the metrics captured at save time,
    and which metric (and direction) is used to rank it against other checkpoints.
    """
    path: str = ''
    metrics: Dict[str, float] = None  # all metrics at time of checkpoint save
    metric_name: str = 'loss'
    metric_decreasing: bool = True
    epoch: int = 0
    global_step: int = 0

    @property
    def valid_key(self):
        # True if the ranking metric is present in the saved metrics
        return self.metric_name and self.metrics and self.metric_name in self.metrics

    @property
    def sort_key(self):
        # rank by the chosen metric when available, otherwise fall back to the epoch
        return self.metrics[self.metric_name] if self.valid_key else self.epoch

    @property
    def decreasing_key(self):
        return self.metric_decreasing if self.valid_key else False


class CheckpointManager:
    """ Manage training checkpoints: keep the top-n checkpoints ranked by a chosen metric,
    maintain 'last' / 'best' checkpoints, and write periodic recovery checkpoints.
    """
    def __init__(
            self,
            hparams=None,
            save_state_fn=None,
            checkpoint_dir='',
            recovery_dir='',
            checkpoint_tmpl=None,
            recovery_tmpl=None,
            metric_name='loss',
            metric_decreasing=True,
            max_history=10):

        # extra items to include in checkpoint
        self.hparams = hparams  # train arguments (config / hparams)  # FIXME this will change with new config system

        # state
        self.checkpoint_files: List[CheckpointInfo] = []  # CheckpointInfo entries in order of decreasing 'betterness'
        self.best_checkpoint = None
        self.curr_recovery_file = ''
        self.prev_recovery_file = ''
        self.can_hardlink = True

        # util / helper fn
        self.save_state_fn = save_state_fn or save_train_state

        # file / folder config
        self.extension = '.pth.tar'
        self.checkpoint_dir = checkpoint_dir
        self.recovery_dir = recovery_dir
        self.checkpoint_tmpl = (checkpoint_tmpl or 'checkpoint-{index}') + self.extension
        self.recovery_tmpl = (recovery_tmpl or 'recovery-{index}') + self.extension

        # ordering / history config
        self.metric_name = metric_name
        self.metric_decreasing = metric_decreasing
        self.metric_cmp_fn = operator.lt if metric_decreasing else operator.gt
        self.max_history = max_history
        assert self.max_history >= 1
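
    # Save pattern used below: checkpoints are first written to a temp file, then moved into place with
    # os.replace(), and duplicated to additional names ('last', 'best', 'checkpoint-N') via a hardlink
    # (os.link) when the filesystem allows it, falling back to a full copy otherwise. This avoids readers
    # ever observing a half-written checkpoint file.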
    def _replace(self, src, dst):
        if self.can_hardlink:
            try:
                if os.path.exists(dst):
                    os.unlink(dst)  # required for Windows support.
            except Exception as e:
                self.can_hardlink = False
        os.replace(src, dst)

    def _duplicate(self, src, dst):
        if self.can_hardlink:
            try:
                if os.path.exists(dst):
                    # for Windows
                    os.unlink(dst)
                os.link(src, dst)
                return
            except Exception as e:
                self.can_hardlink = False
        shutil.copy2(src, dst)

    def _save(self, save_path, train_state: TrainState, metrics: Optional[Dict[str, float]] = None):
        extra_state = dict(
            # version < 2 increments epoch before save
            # version < 3, pre timm bits
            # version 3, first timm bits checkpoints
            version=3,
        )
        if self.hparams is not None:
            extra_state.update(dict(arch=self.hparams['model'], hparams=self.hparams))
        else:
            arch = getattr(train_state.model, 'default_cfg', dict()).get('architecture', None)
            if arch is None:
                arch = type(train_state.model).__name__.lower()
            extra_state.update(dict(arch=arch))
        if metrics is not None:
            # save the metrics and how we originally sorted them in the checkpoint for future comparisons
            extra_state.update(dict(
                metrics=metrics,
                metric_name=self.metric_name,
                metric_decreasing=self.metric_decreasing
            ))

        self.save_state_fn(save_path, train_state, extra_state)

        checkpoint_info = CheckpointInfo(
            path=save_path,
            metrics=metrics,
            metric_name=self.metric_name,
            metric_decreasing=self.metric_decreasing,
            epoch=train_state.epoch,
            global_step=train_state.step_count_global,
        )
        return checkpoint_info

    def _update_checkpoints(self, info: CheckpointInfo):
        self.checkpoint_files.append(info)
        self.checkpoint_files = sorted(
            self.checkpoint_files,
            key=lambda x: x.sort_key,
            reverse=not info.decreasing_key,  # sort in descending order if a lower metric is not better
        )

    def _cleanup_checkpoints(self, trim=0):
        trim = min(len(self.checkpoint_files), trim)
        delete_index = self.max_history - trim
        if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
            return
        to_delete = self.checkpoint_files[delete_index:]
        for d in to_delete:
            try:
                _logger.debug("Cleaning checkpoint: {}".format(d))
                os.remove(d.path)
            except Exception as e:
                _logger.error("Exception '{}' while deleting checkpoint".format(e))
        self.checkpoint_files = self.checkpoint_files[:delete_index]

    def _compare_metric(self, lhs: CheckpointInfo, rhs: CheckpointInfo):
        # compare metrics against an existing checkpoint
        if not lhs or not lhs.valid_key or not rhs or not rhs.valid_key:
            # always assume lhs metrics are better if there are no usable metrics to compare
            return True
        return self.metric_cmp_fn(lhs.sort_key, rhs.sort_key)
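
    # Each call to save_checkpoint() rewrites 'last' + extension, possibly adds a 'checkpoint-<epoch>'
    # entry to the ranked top-n history, and refreshes 'best' + extension when the tracked metric improves.
    # With the defaults the checkpoint_dir would contain files like: last.pth.tar, best.pth.tar,
    # checkpoint-12.pth.tar, ...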
    def save_checkpoint(self, train_state: TrainState, metrics: Optional[Dict[str, float]] = None):
        assert train_state.epoch >= 0
        tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
        last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
        curr_checkpoint = self._save(tmp_save_path, train_state, metrics)
        self._replace(tmp_save_path, last_save_path)

        worst_checkpoint = self.checkpoint_files[-1] if self.checkpoint_files else None
        if len(self.checkpoint_files) < self.max_history or self._compare_metric(curr_checkpoint, worst_checkpoint):
            if len(self.checkpoint_files) >= self.max_history:
                self._cleanup_checkpoints(1)

            filename = self.checkpoint_tmpl.format(index=train_state.epoch)
            save_path = os.path.join(self.checkpoint_dir, filename)
            curr_checkpoint = replace(curr_checkpoint, path=save_path)
            self._duplicate(last_save_path, save_path)
            self._update_checkpoints(curr_checkpoint)

            checkpoints_str = "Current checkpoints:\n"
            for c in self.checkpoint_files:
                checkpoints_str += f'  {c.path}, {c.sort_key}\n'
            _logger.info(checkpoints_str)

            if curr_checkpoint.valid_key and self._compare_metric(curr_checkpoint, self.best_checkpoint):
                self.best_checkpoint = curr_checkpoint
                best_save_path = os.path.join(self.checkpoint_dir, 'best' + self.extension)
                self._duplicate(last_save_path, best_save_path)

        return None if self.best_checkpoint is None else curr_checkpoint
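
    # Recovery checkpoints are written on a step interval, independently of the ranked epoch checkpoints
    # above; the file from two saves ago is removed after each save, so at most the two most recent
    # recovery files remain on disk.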
    def save_recovery(self, train_state: TrainState):
        tmp_save_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension)
        self._save(tmp_save_path, train_state)

        filename = self.recovery_tmpl.format(index=train_state.step_count_global)
        save_path = os.path.join(self.recovery_dir, filename)
        self._replace(tmp_save_path, save_path)

        if os.path.exists(self.prev_recovery_file):
            try:
                _logger.debug("Cleaning recovery: {}".format(self.prev_recovery_file))
                os.remove(self.prev_recovery_file)
            except Exception as e:
                _logger.error("Exception '{}' while removing {}".format(e, self.prev_recovery_file))
        self.prev_recovery_file = self.curr_recovery_file
        self.curr_recovery_file = save_path

    def find_recovery(self):
        # glob for existing recovery checkpoints using the filename prefix from the recovery template
        recovery_prefix = self.recovery_tmpl.split('{')[0]
        recovery_path = os.path.join(self.recovery_dir, recovery_prefix)
        files = glob.glob(recovery_path + '*' + self.extension)
        files = sorted(files)
        return files[0] if len(files) else ''