Add basic training resume based on legacy code

pull/1239/head
Ross Wightman 3 years ago
parent 4210d922d2
commit 5b9c69e80a

@@ -4,55 +4,63 @@ from collections import OrderedDict
import torch
from timm.utils import unwrap_model
from .train_state import TrainState, serialize_train_state, deserialize_train_state
_logger = logging.getLogger(__name__)
def _load_state_dict(checkpoint, state_dict_key='state_dict'):
    new_state_dict = OrderedDict()
    for k, v in checkpoint[state_dict_key].items():
        name = k[7:] if k.startswith('module') else k
        new_state_dict[name] = v
    return new_state_dict


def resume_train_checkpoint(
        train_state,
        train_state: TrainState,
        checkpoint_path,
        resume_opt=True,
        deserialize_fn=deserialize_train_state,
        log_info=True):
    raise NotImplementedError
    # resume_epoch = None
    # if os.path.isfile(checkpoint_path):
    #     checkpoint = torch.load(checkpoint_path, map_location='cpu')
    #
    #     if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    #         if log_info:
    #             _logger.info('Restoring model state from checkpoint...')
    #         new_state_dict = OrderedDict()
    #         for k, v in checkpoint['state_dict'].items():
    #             name = k[7:] if k.startswith('module') else k
    #             new_state_dict[name] = v
    #         model.load_state_dict(new_state_dict)
    #
    #         if optimizer is not None and 'optimizer' in checkpoint:
    #             if log_info:
    #                 _logger.info('Restoring optimizer state from checkpoint...')
    #             optimizer.load_state_dict(checkpoint['optimizer'])
    #
    #         if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
    #             if log_info:
    #                 _logger.info('Restoring AMP loss scaler state from checkpoint...')
    #             loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
    #
    #         if 'epoch' in checkpoint:
    #             resume_epoch = checkpoint['epoch']
    #             if 'version' in checkpoint and checkpoint['version'] > 1:
    #                 resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
    #
    #         if log_info:
    #             _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
    #     else:
    #         model.load_state_dict(checkpoint)
    #         if log_info:
    #             _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
    #     return resume_epoch
    # else:
    #     _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
    #     raise FileNotFoundError()
    # FIXME this is a hacky adaptation of pre-bits resume to get up and running quickly
    resume_epoch = None
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint
        if log_info:
            _logger.info('Restoring model state from checkpoint...')
        train_state.model.load_state_dict(_load_state_dict(checkpoint))

        if train_state.model_ema is not None and 'state_dict_ema' in checkpoint:
            if log_info:
                _logger.info('Restoring model (EMA) state from checkpoint...')
            unwrap_model(train_state.model_ema).load_state_dict(_load_state_dict(checkpoint, 'state_dict_ema'))

        if resume_opt:
            if train_state.updater.optimizer is not None and 'optimizer' in checkpoint:
                if log_info:
                    _logger.info('Restoring optimizer state from checkpoint...')
                train_state.updater.optimizer.load_state_dict(checkpoint['optimizer'])

            scaler_state_dict_key = 'amp_scaler'
            if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint:
                if log_info:
                    _logger.info('Restoring AMP loss scaler state from checkpoint...')
                train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key])

        if 'epoch' in checkpoint:
            resume_epoch = checkpoint['epoch']
            if 'version' in checkpoint and checkpoint['version'] > 1:
                resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
            train_state.epoch = resume_epoch  # FIXME use replace if we make train_state read-only

        if log_info:
            _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
    else:
        _logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()
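
For reference, a minimal sketch of the checkpoint layout that the resume_train_checkpoint() added above consumes. The key names ('state_dict', 'state_dict_ema', 'optimizer', 'amp_scaler', 'epoch', 'version') are taken from the diff; the save-side helper itself is illustrative and is not part of this commit.

import torch


def save_resumable_checkpoint(path, model, optimizer=None, model_ema=None, grad_scaler=None, epoch=0):
    # Illustrative only: writes a dict with the keys the resume path above looks for.
    checkpoint = {
        'state_dict': model.state_dict(),  # any 'module.' prefix is stripped on load
        'epoch': epoch,
        'version': 2,  # version > 1 -> resume starts at epoch + 1
    }
    if optimizer is not None:
        checkpoint['optimizer'] = optimizer.state_dict()
    if model_ema is not None:
        checkpoint['state_dict_ema'] = model_ema.state_dict()
    if grad_scaler is not None:
        checkpoint['amp_scaler'] = grad_scaler.state_dict()
    torch.save(checkpoint, path)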

@@ -340,11 +340,11 @@ def main():
                loader_train.mixup_enabled = False
        train_metrics = train_one_epoch(
            dev_env=dev_env,
            state=train_state,
            services=services,
            cfg=train_cfg,
            loader=loader_train
            services=services,
            loader=loader_train,
            dev_env=dev_env,
        )
        if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -356,8 +356,8 @@ def main():
            train_state.model,
            train_state.eval_loss,
            loader_eval,
            dev_env,
            logger=services.logger)
            services.logger,
            dev_env)
        if train_state.model_ema is not None and not args.model_ema_force_cpu:
            if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -367,8 +367,8 @@ def main():
                train_state.model_ema.module,
                train_state.eval_loss,
                loader_eval,
                services.logger,
                dev_env,
                logger=services.logger,
                phase_suffix='EMA')
            eval_metrics = ema_eval_metrics
@@ -432,6 +432,7 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
        clip_value=args.clip_grad,
        model_ema=args.model_ema,
        model_ema_decay=args.model_ema_decay,
        resume_path=args.resume,
        use_syncbn=args.sync_bn,
    )
@@ -543,11 +544,11 @@ def setup_data(args, default_cfg, dev_env, mixup_active):
def train_one_epoch(
        dev_env: DeviceEnv,
        state: TrainState,
        cfg: TrainCfg,
        services: TrainServices,
        loader,
        dev_env: DeviceEnv,
):
    tracker = Tracker()
    loss_meter = AvgTensor()
@@ -571,10 +572,10 @@ def train_one_epoch(
        state.updater.after_step(
            after_train_step,
            dev_env,
            state,
            services,
            cfg,
            services,
            dev_env,
            step_idx,
            step_end_idx,
            tracker,
@@ -592,10 +593,10 @@ def after_train_step(
def after_train_step(
        dev_env: DeviceEnv,
        state: TrainState,
        services: TrainServices,
        cfg: TrainCfg,
        services: TrainServices,
        dev_env: DeviceEnv,
        step_idx: int,
        step_end_idx: int,
        tracker: Tracker,
@@ -640,8 +641,8 @@ def evaluate(
        model: nn.Module,
        loss_fn: nn.Module,
        loader,
        dev_env: DeviceEnv,
        logger: Logger,
        dev_env: DeviceEnv,
        phase_suffix: str = '',
        log_interval: int = 10,
):
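
Below the diff, a hedged sketch of how the new resume path is presumably wired in: setup_train_task() now receives resume_path=args.resume and hands it to resume_train_checkpoint(), which sets train_state.epoch so the epoch loop can pick up where the checkpoint left off. The maybe_resume() wrapper, the import path, and the start-epoch handling here are illustrative, not code from this commit.

# Module path is an assumption; resume_train_checkpoint is defined in the checkpoint.py hunk above.
from timm.bits import resume_train_checkpoint


def maybe_resume(train_state, resume_path, resume_opt=True):
    # Illustrative wrapper: only resume_train_checkpoint() and the TrainState fields come from the commit.
    if resume_path:
        resume_train_checkpoint(
            train_state,
            resume_path,
            resume_opt=resume_opt,  # also restore optimizer / AMP grad scaler state
        )
    # resume_train_checkpoint() sets train_state.epoch; the training loop would then
    # iterate from that epoch rather than from zero.
    return train_state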
