pytorch-image-models/timm/bits/checkpoint.py


import logging
import os
from collections import OrderedDict
from typing import Dict, Any, Callable

import torch

from timm.utils import unwrap_model

from .device_env import DeviceEnv
from .train_state import TrainState

_logger = logging.getLogger(__name__)

def save_train_state(
        checkpoint_path: str,  # FIXME pass base path + file pattern + epoch / step separately for DS?
        train_state: TrainState,
        extra_state: Dict[str, Any] = None,
        unwrap_fn: Callable = unwrap_model,
        dev_env: DeviceEnv = None,
        log_info: bool = True):
    assert not train_state.updater.deepspeed
    # DeepSpeed has a fully custom checkpoint saving setup, it is not possible to
    # specify a filename, checkpoints need to be saved from all ranks, etc.
    # if train_state.updater.deepspeed:
    #     save_train_state_deepspeed(train_state, checkpoint_path)

    dev_env = dev_env or DeviceEnv.instance()
    state_dict = train_state.state_dict(unwrap_fn=unwrap_fn)
    if extra_state:
        state_dict.update(extra_state)
    if dev_env.type_xla:
        # XLA state dict needs to be moved to CPU before save, this is normally done by xm.save
        state_dict = dev_env.state_dict_to_cpu(state_dict)
    torch.save(state_dict, checkpoint_path)
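
# Example usage (illustrative sketch, not part of the original module). Assuming a
# TrainState assembled elsewhere in timm.bits, and `output_dir` / `eval_metric`
# provided by the caller, a periodic save from a training loop might look like:
#
#   save_train_state(
#       os.path.join(output_dir, 'checkpoint-{}.pth.tar'.format(train_state.epoch)),
#       train_state,
#       extra_state=dict(metric=eval_metric),  # extra keys are merged into the saved dict
#   )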

def load_train_state(
        train_state: TrainState,
        checkpoint_path: str,  # FIXME pass base path + file pattern + epoch / step separately for DS
        unwrap_fn: Callable = None,
        load_opt: bool = True,
        dev_env: DeviceEnv = None,
        log_info: bool = True,
):
    unwrap_fn = unwrap_fn or unwrap_model
    if not os.path.isfile(checkpoint_path):
        _logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError(checkpoint_path)
    if log_info:
        _logger.info('Restoring training state from checkpoint...')

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    assert isinstance(checkpoint, dict)

    if not checkpoint.get('version', 0) > 2:
        # checkpoints without a version field, or with version <= 2, use the legacy flat format
        load_legacy_checkpoint(train_state, checkpoint=checkpoint, load_opt=load_opt, log_info=log_info)
        if log_info:
            _logger.info("Loaded legacy checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))
        return

    train_state.load_state_dict(checkpoint, unwrap_fn=unwrap_fn, load_opt=load_opt)
    if log_info:
        _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))

def _get_state_dict(checkpoint, state_dict_key='state_dict'):
    new_state_dict = OrderedDict()
    for k, v in checkpoint[state_dict_key].items():
        # strip the 'module.' prefix added by (Distributed)DataParallel wrappers
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    return new_state_dict
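
# For example, a checkpoint saved from a DistributedDataParallel-wrapped model maps
#   'module.conv1.weight' -> 'conv1.weight'
# while keys without the 'module.' prefix pass through unchanged.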

def load_legacy_checkpoint(
        train_state: TrainState,
        checkpoint,
        load_opt=True,
        log_info=True):
    assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint
    train_state.model.load_state_dict(_get_state_dict(checkpoint))

    if train_state.model_ema is not None and 'state_dict_ema' in checkpoint:
        if log_info:
            _logger.info('Restoring model (EMA) state from checkpoint...')
        unwrap_model(train_state.model_ema).load_state_dict(_get_state_dict(checkpoint, 'state_dict_ema'))

    if load_opt:
        if train_state.updater.optimizer is not None and 'optimizer' in checkpoint:
            if log_info:
                _logger.info('Restoring optimizer state from checkpoint...')
            train_state.updater.optimizer.load_state_dict(checkpoint['optimizer'])

        scaler_state_dict_key = 'amp_scaler'
        if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint:
            if log_info:
                _logger.info('Restoring AMP loss scaler state from checkpoint...')
            train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key])

    if 'epoch' in checkpoint:
        resume_epoch = checkpoint['epoch']
        if 'version' in checkpoint and checkpoint['version'] > 1:
            resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
        train_state.epoch = resume_epoch  # FIXME use replace if we make train_state read-only
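
# For reference, the legacy (version <= 2) checkpoint layout consumed above is a flat
# dict along these lines (all keys beyond 'state_dict' are optional):
#
#   {
#       'state_dict': ...,       # model weights, possibly with 'module.' prefixes
#       'state_dict_ema': ...,   # EMA model weights
#       'optimizer': ...,        # optimizer state dict
#       'amp_scaler': ...,       # AMP GradScaler state dict
#       'epoch': 10,
#       'version': 2,
#   }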