Add basic training resume based on legacy code

pull/1239/head
Ross Wightman 3 years ago
parent 4210d922d2
commit 5b9c69e80a

@@ -4,55 +4,63 @@ from collections import OrderedDict
 import torch
+from timm.utils import unwrap_model
 from .train_state import TrainState, serialize_train_state, deserialize_train_state
 
 _logger = logging.getLogger(__name__)
 
 
+def _load_state_dict(checkpoint, state_dict_key='state_dict'):
+    new_state_dict = OrderedDict()
+    for k, v in checkpoint[state_dict_key].items():
+        name = k[7:] if k.startswith('module') else k
+        new_state_dict[name] = v
+    return new_state_dict
+
+
 def resume_train_checkpoint(
-        train_state,
+        train_state: TrainState,
         checkpoint_path,
         resume_opt=True,
         deserialize_fn=deserialize_train_state,
         log_info=True):
-    raise NotImplementedError
-
-    # resume_epoch = None
-    # if os.path.isfile(checkpoint_path):
-    #     checkpoint = torch.load(checkpoint_path, map_location='cpu')
-    #
-    #     if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-    #         if log_info:
-    #             _logger.info('Restoring model state from checkpoint...')
-    #         new_state_dict = OrderedDict()
-    #         for k, v in checkpoint['state_dict'].items():
-    #             name = k[7:] if k.startswith('module') else k
-    #             new_state_dict[name] = v
-    #         model.load_state_dict(new_state_dict)
-    #
-    #         if optimizer is not None and 'optimizer' in checkpoint:
-    #             if log_info:
-    #                 _logger.info('Restoring optimizer state from checkpoint...')
-    #             optimizer.load_state_dict(checkpoint['optimizer'])
-    #
-    #         if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
-    #             if log_info:
-    #                 _logger.info('Restoring AMP loss scaler state from checkpoint...')
-    #             loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
-    #
-    #         if 'epoch' in checkpoint:
-    #             resume_epoch = checkpoint['epoch']
-    #             if 'version' in checkpoint and checkpoint['version'] > 1:
-    #                 resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
-    #
-    #         if log_info:
-    #             _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
-    #     else:
-    #         model.load_state_dict(checkpoint)
-    #         if log_info:
-    #             _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
-    #     return resume_epoch
-    # else:
-    #     _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
-    #     raise FileNotFoundError()
+    # FIXME this is a hacky adaptation of pre-bits resume to get up and running quickly
+    resume_epoch = None
+    if os.path.isfile(checkpoint_path):
+        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint
+        if log_info:
+            _logger.info('Restoring model state from checkpoint...')
+        train_state.model.load_state_dict(_load_state_dict(checkpoint))
+
+        if train_state.model_ema is not None and 'state_dict_ema' in checkpoint:
+            if log_info:
+                _logger.info('Restoring model (EMA) state from checkpoint...')
+            unwrap_model(train_state.model_ema).load_state_dict(_load_state_dict(checkpoint, 'state_dict_ema'))
+
+        if resume_opt:
+            if train_state.updater.optimizer is not None and 'optimizer' in checkpoint:
+                if log_info:
+                    _logger.info('Restoring optimizer state from checkpoint...')
+                train_state.updater.optimizer.load_state_dict(checkpoint['optimizer'])
+
+            scaler_state_dict_key = 'amp_scaler'
+            if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint:
+                if log_info:
+                    _logger.info('Restoring AMP loss scaler state from checkpoint...')
+                train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key])
+
+        if 'epoch' in checkpoint:
+            resume_epoch = checkpoint['epoch']
+            if 'version' in checkpoint and checkpoint['version'] > 1:
+                resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
+            train_state.epoch = resume_epoch  # FIXME use replace if we make train_state read-only
+
+        if log_info:
+            _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
+    else:
+        _logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path))
+        raise FileNotFoundError()
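As a reference for what this resume path consumes, the expected checkpoint layout can be read off the keys above; below is a hedged sketch of a compatible writer. The populated train_state fields and the output filename are placeholders, and timm's actual checkpoint writing lives in its separate saver utility rather than in this function.

# Sketch of a checkpoint dict that resume_train_checkpoint can consume.
# Key names match what is read above; everything else is illustrative.
checkpoint = {
    'state_dict': train_state.model.state_dict(),
    'state_dict_ema': unwrap_model(train_state.model_ema).state_dict(),  # optional
    'optimizer': train_state.updater.optimizer.state_dict(),             # optional
    'amp_scaler': train_state.updater.grad_scaler.state_dict(),          # optional
    'epoch': train_state.epoch,
    'version': 2,  # version > 1 => training resumes at epoch + 1
}
torch.save(checkpoint, 'checkpoint.pth.tar')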

@@ -340,11 +340,11 @@ def main():
                 loader_train.mixup_enabled = False
 
         train_metrics = train_one_epoch(
-            dev_env=dev_env,
             state=train_state,
-            services=services,
             cfg=train_cfg,
-            loader=loader_train
+            services=services,
+            loader=loader_train,
+            dev_env=dev_env,
         )
 
         if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -356,8 +356,8 @@ def main():
             train_state.model,
             train_state.eval_loss,
             loader_eval,
-            dev_env,
-            logger=services.logger)
+            services.logger,
+            dev_env)
 
         if train_state.model_ema is not None and not args.model_ema_force_cpu:
             if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -367,8 +367,8 @@ def main():
                 train_state.model_ema.module,
                 train_state.eval_loss,
                 loader_eval,
+                services.logger,
                 dev_env,
-                logger=services.logger,
                 phase_suffix='EMA')
 
             eval_metrics = ema_eval_metrics
@@ -432,6 +432,7 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
         clip_value=args.clip_grad,
         model_ema=args.model_ema,
         model_ema_decay=args.model_ema_decay,
+        resume_path=args.resume,
         use_syncbn=args.sync_bn,
     )
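The new resume_path kwarg presumably just carries args.resume to wherever the TrainState from the first hunk is assembled, so that resume_train_checkpoint can run once the model, EMA, and updater exist. A rough sketch of that wiring follows; the helper name and the TrainState construction are assumptions, and only resume_train_checkpoint and the resume_path kwarg come from this diff.

def _setup_train_state(model, updater, model_ema, resume_path: str = '') -> TrainState:
    # hypothetical consumer of resume_path
    train_state = TrainState(model=model, updater=updater, model_ema=model_ema)
    if resume_path:
        # restores model / EMA / optimizer / scaler state and sets train_state.epoch
        resume_train_checkpoint(train_state, resume_path, resume_opt=True)
    return train_state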
@@ -543,11 +544,11 @@ def setup_data(args, default_cfg, dev_env, mixup_active):
 def train_one_epoch(
-        dev_env: DeviceEnv,
         state: TrainState,
         cfg: TrainCfg,
         services: TrainServices,
         loader,
+        dev_env: DeviceEnv,
 ):
     tracker = Tracker()
     loss_meter = AvgTensor()
@@ -571,10 +572,10 @@ def train_one_epoch(
         state.updater.after_step(
             after_train_step,
-            dev_env,
             state,
-            services,
             cfg,
+            services,
+            dev_env,
             step_idx,
             step_end_idx,
             tracker,
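The argument reshuffle above (and the matching signature change in the next hunk) only works if updater.after_step forwards its positional arguments to the callback verbatim once the optimizer step has been applied. A minimal sketch of that assumed contract follows; it is an illustrative stand-in, not the actual timm.bits Updater.

class _UpdaterSketch:
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self._after_step_fn, self._after_step_args = None, ()

    def after_step(self, fn, *args):
        # stash the callback; the argument order here must match the callback's parameters
        self._after_step_fn, self._after_step_args = fn, args

    def step(self):
        self.optimizer.step()
        self.optimizer.zero_grad()
        if self._after_step_fn is not None:
            self._after_step_fn(*self._after_step_args)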
@@ -592,10 +593,10 @@ def train_one_epoch(
 def after_train_step(
-        dev_env: DeviceEnv,
         state: TrainState,
-        services: TrainServices,
         cfg: TrainCfg,
+        services: TrainServices,
+        dev_env: DeviceEnv,
         step_idx: int,
         step_end_idx: int,
         tracker: Tracker,
@@ -640,8 +641,8 @@ def evaluate(
     model: nn.Module,
     loss_fn: nn.Module,
     loader,
-    dev_env: DeviceEnv,
     logger: Logger,
+    dev_env: DeviceEnv,
     phase_suffix: str = '',
     log_interval: int = 10,
 ):
