Fix a few bugs and formatting/naming issues

* Pass optimizer resume flag through to checkpoint / updater restore. Related to #961, but it's not clear how it relates to the crash.
* Rename monitor step args, clean up handling of step_end_idx vs num_steps for consistent log output in either case
* Resume from the proper epoch (i.e. the next epoch relative to the checkpoint)
pull/1239/head
Ross Wightman 3 years ago
parent 406c486ba2
commit 80ca078aed

@@ -62,7 +62,7 @@ def load_train_state(
_logger.info("Loaded legacy checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))
return
-train_state.load_state_dict(checkpoint, unwrap_fn=unwrap_fn)
+train_state.load_state_dict(checkpoint, unwrap_fn=unwrap_fn, load_opt=load_opt)
if log_info:
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))

@@ -43,13 +43,6 @@ except ImportError:
# f' Data: {data_time.smooth_val:.3f} ({data_time.avg:.3f})'
# log_str += f' Loss: {loss.smooth_val:>9.6f} ({loss.avg:>6.4f}) '
# log_str += f' LR: {lr:.3e} '
-#
-# if args.save_images and output_dir:
-# torchvision.utils.save_image(
-# input,
-# os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
-# padding=0,
-# normalize=True)
def summary_row_dict(results, index=None, index_name='epoch'):
@@ -159,8 +152,8 @@ class Monitor:
def log_step(
self,
phase: str,
-step: int,
-step_end: Optional[int] = None,
+step_idx: int,
+step_end_idx: Optional[int] = None,
epoch: Optional[int] = None,
loss: Optional[float] = None,
rate: Optional[float] = None,
@@ -171,14 +164,15 @@ class Monitor:
"""
if not self.output_enabled:
return
+if 'num_steps' in kwargs:
+step_end_idx = max(0, kwargs.pop('num_steps') - 1)
phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}:'
-progress = 100. * step / step_end if step_end else 0.
+progress = 100. * step_idx / step_end_idx if step_end_idx else 0.
text_update = [
phase_title,
f'{epoch}' if epoch is not None else None,
-f'[{step}]' if step_end is None else None,
-f'[{step}/{step_end} ({progress:>3.0f}%)]' if step_end is not None else None,
+f'[{step_idx}]' if step_end_idx is None else None,
+f'[{step_idx}/{step_end_idx} ({progress:>3.0f}%)]' if step_end_idx is not None else None,
f'Rate: {rate:.2f}/s' if rate is not None else None,
f'Loss: {loss:.5f}' if loss is not None else None,
]
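For illustration, a standalone sketch of the step_end_idx vs num_steps normalization above: callers may pass either the last step index or a total step count, and both should yield the same progress string. The helper name below is hypothetical, not the Monitor API.

from typing import Optional

def format_progress(step_idx: int, step_end_idx: Optional[int] = None, **kwargs) -> str:
    if 'num_steps' in kwargs:
        # num_steps is a count; the last valid index is one less
        step_end_idx = max(0, kwargs.pop('num_steps') - 1)
    if step_end_idx is None:
        return f'[{step_idx}]'
    progress = 100. * step_idx / step_end_idx if step_end_idx else 0.
    return f'[{step_idx}/{step_end_idx} ({progress:>3.0f}%)]'

assert format_progress(49, step_end_idx=99) == format_progress(49, num_steps=100)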

@@ -45,13 +45,13 @@ class TrainState:
train_cfg=vars(self.train_cfg)
)
# FIXME include lr_scheduler state?
-state.update(self.updater.state_dict()) # updater (optimizer, scaler,e tc) state added to state
+state.update(self.updater.state_dict()) # updater (optimizer, scaler, etc.) state added to state
return state
def load_state_dict(self, state_dict, unwrap_fn=unwrap_model, load_opt=True):
# restore train loop state
-self.epoch = state_dict['epoch']
-self.step_count = state_dict['step_count']
+self.epoch = state_dict['epoch'] + 1
+self.step_count = 0 # FIXME need more logic to restore part way through epoch
self.step_count_global = state_dict['step_count_global']
# restore model params / state
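The resume semantics here assume a checkpoint is written after an epoch completes, so training restarts at the following epoch with the per-epoch step counter reset while the global step count carries over. A minimal sketch with an illustrative helper name:

def restore_loop_state(state_dict):
    # the saved epoch already finished; resume at the next one
    resume_epoch = state_dict['epoch'] + 1
    # restoring part way through an epoch would also need dataloader state; not handled here
    step_count = 0
    step_count_global = state_dict['step_count_global']
    return resume_epoch, step_count, step_count_global

assert restore_loop_state({'epoch': 9, 'step_count_global': 12500}) == (10, 0, 12500)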

@@ -452,6 +452,7 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
model_ema=args.model_ema,
model_ema_decay=args.model_ema_decay,
resume_path=args.resume,
+resume_opt=not args.no_resume_opt,
use_syncbn=args.sync_bn,
)
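For illustration, a hedged sketch of mapping a negative '--no-resume-opt' style CLI flag onto the positive resume_opt argument used above; the argparse definitions below are assumptions for the sketch rather than the verbatim train.py flags.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--resume', default='', type=str,
                    help='Resume model and optimizer state from checkpoint path')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='Prevent resume of optimizer state when resuming model')
args = parser.parse_args(['--resume', 'output/last.pth.tar', '--no-resume-opt'])

resume_opt = not args.no_resume_opt  # store_true flag inverted into a positive option
assert resume_opt is False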
@@ -683,7 +684,7 @@ def after_train_step(
Returns:
"""
-end_step = step_idx == step_end_idx
+last_step = step_idx == step_end_idx
with torch.no_grad():
output, target, loss = tensors
@@ -696,15 +697,15 @@
state = replace(state, step_count_global=state.step_count_global + 1)
cfg = state.train_cfg
-if services.monitor is not None and end_step or (step_idx + 1) % cfg.log_interval == 0:
+if services.monitor is not None and last_step or (step_idx + 1) % cfg.log_interval == 0:
global_batch_size = dev_env.world_size * output.shape[0]
loss_avg = loss_meter.compute()
if services.monitor is not None:
lr_avg = state.updater.get_average_lr()
services.monitor.log_step(
'Train',
-step=step_idx,
-step_end=step_end_idx,
+step_idx=step_idx,
+step_end_idx=step_end_idx,
epoch=state.epoch,
loss=loss_avg.item(),
rate=tracker.get_avg_iter_rate(global_batch_size),
@@ -712,8 +713,8 @@
)
if services.checkpoint is not None and cfg.recovery_interval and (
-end_step or (step_idx + 1) % cfg.recovery_interval == 0):
-services.checkpoint.save_recovery(state.epoch, batch_idx=step_idx)
+last_step or (step_idx + 1) % cfg.recovery_interval == 0):
+services.checkpoint.save_recovery(state)
if state.lr_scheduler is not None:
# FIXME perform scheduler update here or via updater after_step call?
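A simplified sketch of the "last step or every N steps" trigger now shared by step logging and recovery checkpointing; the real conditions above additionally check that the monitor / checkpoint service exists, and the function name here is illustrative.

def should_fire(step_idx, step_end_idx, interval):
    last_step = step_idx == step_end_idx
    if last_step:
        return True
    # an interval of 0 disables periodic firing
    return bool(interval) and (step_idx + 1) % interval == 0

assert should_fire(step_idx=99, step_end_idx=99, interval=50)       # final step always fires
assert should_fire(step_idx=49, step_end_idx=99, interval=50)       # interval boundary
assert not should_fire(step_idx=10, step_end_idx=99, interval=0)    # interval disabled, not last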
@@ -770,8 +771,8 @@ def evaluate(
loss_avg = losses_m.compute()
logger.log_step(
'Eval',
-step=step_idx,
-step_end=end_idx,
+step_idx=step_idx,
+step_end_idx=end_idx,
loss=loss_avg.item(),
top1=top1.item(),
top5=top5.item(),

@@ -173,6 +173,7 @@ def validate(args):
with torch.no_grad():
tracker.mark_iter()
for step_idx, (sample, target) in enumerate(loader):
+last_step = step_idx == num_steps - 1
tracker.mark_iter_data_end()
# compute output
@@ -197,12 +198,12 @@
accuracy.update(output.detach(), target)
tracker.mark_iter()
-if step_idx % args.log_freq == 0:
+if last_step or step_idx % args.log_freq == 0:
top1, top5 = accuracy.compute().values()
loss_avg = losses.compute()
logger.log_step(
phase='eval',
-step=step_idx,
+step_idx=step_idx,
num_steps=num_steps,
rate=args.batch_size / tracker.iter_time.avg,
loss=loss_avg.item(),
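An illustrative skeleton of the logging change above: deriving last_step from the step count guarantees the final batch is logged even when it does not land on a log_freq boundary. The values below are made up for the example.

num_steps = 7    # e.g. len(loader)
log_freq = 3
logged = []
for step_idx in range(num_steps):
    last_step = step_idx == num_steps - 1
    if last_step or step_idx % log_freq == 0:
        logged.append(step_idx)

# steps 0 and 3 hit the log_freq boundary; step 6 is logged only because it is last
assert logged == [0, 3, 6]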
