import logging
import os
from collections import OrderedDict
from typing import Dict, Any, Callable

import torch

from timm.utils import unwrap_model

from .device_env import DeviceEnv
from .train_state import TrainState

_logger = logging.getLogger(__name__)


def save_train_state(
        checkpoint_path: str,  # FIXME pass base path + file pattern + epoch / step separately for DS?
        train_state: TrainState,
        extra_state: Dict[str, Any] = None,
        unwrap_fn: Callable = unwrap_model,
        dev_env: DeviceEnv = None,
        log_info: bool = True):
    assert not train_state.updater.deepspeed
    # DeepSpeed has a fully custom checkpoint saving setup, it is not possible to
    # specify a filename, checkpoints need to be saved from all ranks, etc.
    # if train_state.updater.deepspeed:
    #     save_train_state_deepspeed(train_state, checkpoint_path)

    dev_env = dev_env or DeviceEnv.instance()
    state_dict = train_state.state_dict(unwrap_fn=unwrap_fn)
    if extra_state:
        state_dict.update(extra_state)
    if dev_env.type_xla:
        # XLA state dict needs to be moved to CPU before save, this is normally done by xm.save
        state_dict = dev_env.state_dict_to_cpu(state_dict)
    torch.save(state_dict, checkpoint_path)


def load_train_state(
        train_state: TrainState,
        checkpoint_path: str,  # FIXME pass base path + file pattern + epoch / step separately for DS
        unwrap_fn: Callable = None,
        load_opt: bool = True,
        dev_env: DeviceEnv = None,
        log_info: bool = True,
):
    unwrap_fn = unwrap_fn or unwrap_model
    if not os.path.isfile(checkpoint_path):
        _logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError(checkpoint_path)

    if log_info:
        _logger.info('Restoring training state from checkpoint...')

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    assert isinstance(checkpoint, dict)

    if not checkpoint.get('version', 0) > 2:
        # checkpoints without a version > 2 use the legacy layout, restore field by field
        load_legacy_checkpoint(train_state, checkpoint=checkpoint, load_opt=load_opt, log_info=log_info)
        if log_info:
            _logger.info("Loaded legacy checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))
        return

    train_state.load_state_dict(checkpoint, unwrap_fn=unwrap_fn)
    if log_info:
        _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))


def _get_state_dict(checkpoint, state_dict_key='state_dict'):
    # strip the 'module.' prefix added by DistributedDataParallel / DataParallel wrappers
    new_state_dict = OrderedDict()
    for k, v in checkpoint[state_dict_key].items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    return new_state_dict


def load_legacy_checkpoint(
        train_state: TrainState,
        checkpoint,
        load_opt=True,
        log_info=True):
    assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint
    train_state.model.load_state_dict(_get_state_dict(checkpoint))

    if train_state.model_ema is not None and 'state_dict_ema' in checkpoint:
        if log_info:
            _logger.info('Restoring model (EMA) state from checkpoint...')
        unwrap_model(train_state.model_ema).load_state_dict(_get_state_dict(checkpoint, 'state_dict_ema'))

    if load_opt:
        if train_state.updater.optimizer is not None and 'optimizer' in checkpoint:
            if log_info:
                _logger.info('Restoring optimizer state from checkpoint...')
            train_state.updater.optimizer.load_state_dict(checkpoint['optimizer'])

        scaler_state_dict_key = 'amp_scaler'
        if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint:
            if log_info:
                _logger.info('Restoring AMP loss scaler state from checkpoint...')
            train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key])

    if 'epoch' in checkpoint:
        resume_epoch = checkpoint['epoch']
        if 'version' in checkpoint and checkpoint['version'] > 1:
            resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
        train_state.epoch = resume_epoch  # FIXME use replace if we make train_state read-only
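

# Illustrative usage sketch (not part of the original module): it assumes a `train_state`
# (TrainState) already built by the training harness and a hypothetical './output' directory;
# only the functions defined above are exercised.
#
#   checkpoint_path = os.path.join('./output', 'checkpoint-{}.pth.tar'.format(train_state.epoch))
#   save_train_state(checkpoint_path, train_state, extra_state={'arch': 'resnet50'})
#   ...
#   load_train_state(train_state, checkpoint_path, load_opt=True)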