pull/1239/head
parent 5b9c69e80a
commit 91ab0b6ce5
@@ -0,0 +1,219 @@
""" Checkpoint Manager

Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.

Hacked together by / Copyright 2021 Ross Wightman
"""
import glob
import logging
import operator
import os
import shutil
from typing import Optional, Dict, Callable, List
from dataclasses import dataclass, replace

from .checkpoint import save_train_state
from .train_state import TrainState

_logger = logging.getLogger(__name__)


@dataclass
class CheckpointInfo:
    path: str = ''
    metrics: Dict[str, float] = None  # all metrics at time of checkpoint save
    metric_name: str = 'loss'
    metric_decreasing: bool = True
    epoch: int = 0
    global_step: int = 0

    @property
    def valid_key(self):
        return self.metric_name and self.metrics and self.metric_name in self.metrics

    @property
    def sort_key(self):
        return self.metrics[self.metric_name] if self.valid_key else self.epoch

    @property
    def decreasing_key(self):
        return self.metric_decreasing if self.valid_key else False

class CheckpointManager:
    def __init__(
            self,
            hparams=None,
            save_state_fn=None,
            checkpoint_dir='',
            recovery_dir='',
            checkpoint_tmpl=None,
            recovery_tmpl=None,
            metric_name='loss',
            metric_decreasing=True,
            max_history=10):

        # extra items to include in checkpoint
        self.hparams = hparams  # train arguments (config / hparams), FIXME this will change with new config system

        # state
        self.checkpoint_files: List[CheckpointInfo] = []  # CheckpointInfo entries, kept sorted best-first
        self.best_checkpoint = None
        self.curr_recovery_file = ''
        self.prev_recovery_file = ''
        self.can_hardlink = True

        # util / helper fn
        self.save_state_fn = save_state_fn or save_train_state

        # file / folder config
        self.extension = '.pth.tar'
        self.checkpoint_dir = checkpoint_dir
        self.recovery_dir = recovery_dir
        self.checkpoint_tmpl = (checkpoint_tmpl or 'checkpoint-{index}') + self.extension
        self.recovery_tmpl = (recovery_tmpl or 'recovery-{index}') + self.extension

        # ordering / history config
        self.metric_name = metric_name
        self.metric_decreasing = metric_decreasing
        self.metric_cmp_fn = operator.lt if metric_decreasing else operator.gt
        self.max_history = max_history
        assert self.max_history >= 1

    def _replace(self, src, dst):
        if self.can_hardlink:
            try:
                if os.path.exists(dst):
                    os.unlink(dst)  # required for Windows support.
            except Exception as e:
                self.can_hardlink = False
        os.replace(src, dst)

    def _duplicate(self, src, dst):
        # hard link src to dst when possible; fall back to a full copy if linking isn't supported
        if self.can_hardlink:
            try:
                if os.path.exists(dst):
                    # for Windows
                    os.unlink(dst)
                os.link(src, dst)
                return
            except Exception as e:
                self.can_hardlink = False
        shutil.copy2(src, dst)

    def _save(self, save_path, train_state: TrainState, metrics: Optional[Dict[str, float]] = None):
        extra_state = dict(
            # version < 2 increments epoch before save
            # version < 3, pre timm bits
            # version 3, first timm bits checkpoints
            version=3,
        )
        if self.hparams is not None:
            extra_state.update(dict(arch=self.hparams['model'], hparams=self.hparams))
        else:
            arch = getattr(train_state.model, 'default_cfg', dict()).get('architecture', None)
            if arch is None:
                arch = type(train_state.model).__name__.lower()
            extra_state.update(dict(arch=arch))
        if metrics is not None:
            # save the metrics and how we originally sorted them in the checkpoint for future comparisons
            extra_state.update(dict(
                metrics=metrics,
                metric_name=self.metric_name,
                metric_decreasing=self.metric_decreasing
            ))

        self.save_state_fn(save_path, train_state, extra_state)

        checkpoint_info = CheckpointInfo(
            path=save_path,
            metrics=metrics,
            metric_name=self.metric_name,
            metric_decreasing=self.metric_decreasing,
            epoch=train_state.epoch,
            global_step=train_state.step_count_global,
        )
        return checkpoint_info

    def _update_checkpoints(self, info: CheckpointInfo):
        self.checkpoint_files.append(info)
        self.checkpoint_files = sorted(
            self.checkpoint_files,
            key=lambda x: x.sort_key,
            reverse=not info.decreasing_key,  # sort in descending order if a lower metric is not better
        )

    def _cleanup_checkpoints(self, trim=0):
        trim = min(len(self.checkpoint_files), trim)
        delete_index = self.max_history - trim
        if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
            return
        to_delete = self.checkpoint_files[delete_index:]
        for d in to_delete:
            try:
                _logger.debug("Cleaning checkpoint: {}".format(d))
                os.remove(d.path)
            except Exception as e:
                _logger.error("Exception '{}' while deleting checkpoint".format(e))
        self.checkpoint_files = self.checkpoint_files[:delete_index]

    def _compare_metric(self, lhs: CheckpointInfo, rhs: CheckpointInfo):
        # compare metrics against an existing checkpoint
        if not lhs or not lhs.valid_key or not rhs or not rhs.valid_key:
            # always assume lhs metrics are better if there are no usable metrics to compare
            return True
        return self.metric_cmp_fn(lhs.sort_key, rhs.sort_key)

    def save_checkpoint(self, train_state: TrainState, metrics: Optional[Dict[str, float]] = None):
        assert train_state.epoch >= 0
        tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
        last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
        curr_checkpoint = self._save(tmp_save_path, train_state, metrics)
        self._replace(tmp_save_path, last_save_path)

        worst_checkpoint = self.checkpoint_files[-1] if self.checkpoint_files else None
        if len(self.checkpoint_files) < self.max_history or self._compare_metric(curr_checkpoint, worst_checkpoint):
            if len(self.checkpoint_files) >= self.max_history:
                self._cleanup_checkpoints(1)

            filename = self.checkpoint_tmpl.format(index=train_state.epoch)
            save_path = os.path.join(self.checkpoint_dir, filename)
            curr_checkpoint = replace(curr_checkpoint, path=save_path)
            self._duplicate(last_save_path, save_path)
            self._update_checkpoints(curr_checkpoint)

            checkpoints_str = "Current checkpoints:\n"
            for c in self.checkpoint_files:
                checkpoints_str += f' {c.path}, {c.sort_key}\n'
            _logger.info(checkpoints_str)

            if curr_checkpoint.valid_key and self._compare_metric(curr_checkpoint, self.best_checkpoint):
                self.best_checkpoint = curr_checkpoint
                best_save_path = os.path.join(self.checkpoint_dir, 'best' + self.extension)
                self._duplicate(last_save_path, best_save_path)

        return None if self.best_checkpoint is None else curr_checkpoint

    def save_recovery(self, train_state: TrainState):
        tmp_save_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension)
        self._save(tmp_save_path, train_state)

        filename = self.recovery_tmpl.format(index=train_state.step_count_global)
        save_path = os.path.join(self.recovery_dir, filename)
        self._replace(tmp_save_path, save_path)

        if os.path.exists(self.prev_recovery_file):
            try:
                _logger.debug("Cleaning recovery: {}".format(self.prev_recovery_file))
                os.remove(self.prev_recovery_file)
            except Exception as e:
                _logger.error("Exception '{}' while removing {}".format(e, self.prev_recovery_file))
        self.prev_recovery_file = self.curr_recovery_file
        self.curr_recovery_file = save_path

    def find_recovery(self):
        # the recovery filename prefix is everything in the template before the '{index}' placeholder
        recovery_prefix = self.recovery_tmpl.split('{')[0]
        recovery_path = os.path.join(self.recovery_dir, recovery_prefix)
        files = glob.glob(recovery_path + '*' + self.extension)
        files = sorted(files)
        return files[0] if len(files) else ''
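For context, a minimal usage sketch (not part of this commit) of how a training loop might drive the manager; `train_one_epoch`, `evaluate`, `train_state`, and `num_epochs` are assumed placeholders, and the metric settings assume a top-1 accuracy metric where higher is better.

# Hypothetical wiring, for illustration only -- all names below are assumptions, not part of this PR.
manager = CheckpointManager(
    checkpoint_dir='./output/train',
    recovery_dir='./output/train',
    metric_name='top1',
    metric_decreasing=False,  # higher top-1 accuracy is better
    max_history=5,
)

for epoch in range(num_epochs):
    train_one_epoch(train_state)          # assumed helper that advances train_state by one epoch
    eval_metrics = evaluate(train_state)  # assumed helper returning e.g. {'loss': 0.31, 'top1': 78.2}
    manager.save_checkpoint(train_state, metrics=eval_metrics)
    if manager.best_checkpoint is not None:
        _logger.info(f'Best checkpoint so far: {manager.best_checkpoint.path}')

The manager only promotes a checkpoint into the top-n history (and the 'best' link) when its metric beats the current worst entry, so `max_history` bounds disk usage while `save_recovery` handles crash-recovery snapshots separately.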
@@ -1,13 +1,13 @@
 from dataclasses import dataclass
 
-from .logger import Logger
-from timm.utils.checkpoint_saver import CheckpointSaver
+from .monitor import Monitor
+from .checkpoint_manager import CheckpointManager
 
 
 @dataclass
 class TrainServices:
     """ Train Loop Services
     """
-    logger: Logger = None
-    saver: CheckpointSaver = None
+    logger: Monitor = None
+    checkpoint_manager: CheckpointManager = None
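A hypothetical construction sketch (assumed, not shown in the diff): the container can be filled with only the services a given loop uses, leaving the rest as None.

# Hypothetical: assemble services for a loop that only needs checkpointing; the directory and
# metric settings are assumptions, and `train_state` is assumed to come from the surrounding loop.
services = TrainServices(
    checkpoint_manager=CheckpointManager(checkpoint_dir='./output/train', metric_name='loss'),
)
if services.checkpoint_manager is not None:
    services.checkpoint_manager.save_recovery(train_state)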