From 80ca078aedee177831beeffbf24e0aa48a45909b Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 11 Nov 2021 15:09:31 -0800
Subject: [PATCH] Fix a few bugs and formatting/naming issues

* Pass optimizer resume flag through to checkpoint / updater restore. Related
  to #961 but not clear how it relates to the crash.
* Rename monitor step args, clean up handling of step_end_idx vs num_steps for
  consistent log output in either case
* Resume from proper epoch (i.e. next epoch relative to checkpoint)
---
 timm/bits/checkpoint.py  |  2 +-
 timm/bits/monitor.py     | 20 +++++++-------------
 timm/bits/train_state.py |  6 +++---
 train.py                 | 17 +++++++++--------
 validate.py              |  5 +++--
 5 files changed, 23 insertions(+), 27 deletions(-)

diff --git a/timm/bits/checkpoint.py b/timm/bits/checkpoint.py
index df21ab5e..b7afd731 100644
--- a/timm/bits/checkpoint.py
+++ b/timm/bits/checkpoint.py
@@ -62,7 +62,7 @@ def load_train_state(
             _logger.info("Loaded legacy checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))
         return
 
-    train_state.load_state_dict(checkpoint, unwrap_fn=unwrap_fn)
+    train_state.load_state_dict(checkpoint, unwrap_fn=unwrap_fn, load_opt=load_opt)
     if log_info:
         _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))
 
diff --git a/timm/bits/monitor.py b/timm/bits/monitor.py
index af397e1a..e4dd95f0 100644
--- a/timm/bits/monitor.py
+++ b/timm/bits/monitor.py
@@ -43,13 +43,6 @@ except ImportError:
 #             f' Data: {data_time.smooth_val:.3f} ({data_time.avg:.3f})'
 #         log_str += f' Loss: {loss.smooth_val:>9.6f} ({loss.avg:>6.4f}) '
 #         log_str += f' LR: {lr:.3e} '
-#
-#     if args.save_images and output_dir:
-#         torchvision.utils.save_image(
-#             input,
-#             os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
-#             padding=0,
-#             normalize=True)
 
 
 def summary_row_dict(results, index=None, index_name='epoch'):
@@ -159,8 +152,8 @@ class Monitor:
     def log_step(
             self,
             phase: str,
-            step: int,
-            step_end: Optional[int] = None,
+            step_idx: int,
+            step_end_idx: Optional[int] = None,
             epoch: Optional[int] = None,
             loss: Optional[float] = None,
             rate: Optional[float] = None,
@@ -171,14 +164,15 @@ class Monitor:
         """
         if not self.output_enabled:
             return
-
+        if 'num_steps' in kwargs:
+            step_end_idx = max(0, kwargs.pop('num_steps') - 1)
         phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}:'
-        progress = 100. * step / step_end if step_end else 0.
+        progress = 100. * step_idx / step_end_idx if step_end_idx else 0.
         text_update = [
             phase_title,
             f'{epoch}' if epoch is not None else None,
-            f'[{step}]' if step_end is None else None,
-            f'[{step}/{step_end} ({progress:>3.0f}%)]' if step_end is not None else None,
+            f'[{step_idx}]' if step_end_idx is None else None,
+            f'[{step_idx}/{step_end_idx} ({progress:>3.0f}%)]' if step_end_idx is not None else None,
             f'Rate: {rate:.2f}/s' if rate is not None else None,
             f'Loss: {loss:.5f}' if loss is not None else None,
         ]
diff --git a/timm/bits/train_state.py b/timm/bits/train_state.py
index 91fcf76f..5d20f500 100644
--- a/timm/bits/train_state.py
+++ b/timm/bits/train_state.py
@@ -45,13 +45,13 @@ class TrainState:
             train_cfg=vars(self.train_cfg)
         )
         # FIXME include lr_scheduler state?
-        state.update(self.updater.state_dict())  # updater (optimizer, scaler,e tc) state added to state
+        state.update(self.updater.state_dict())  # updater (optimizer, scaler, etc.) state added to state
        return state

    def load_state_dict(self, state_dict, unwrap_fn=unwrap_model, load_opt=True):
        # restore train loop state
-        self.epoch = state_dict['epoch']
-        self.step_count = state_dict['step_count']
+        self.epoch = state_dict['epoch'] + 1
+        self.step_count = 0  # FIXME need more logic to restore part way through epoch
        self.step_count_global = state_dict['step_count_global']

        # restore model params / state
diff --git a/train.py b/train.py
index 5c4dab8a..fb6b4319 100755
--- a/train.py
+++ b/train.py
@@ -452,6 +452,7 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
         model_ema=args.model_ema,
         model_ema_decay=args.model_ema_decay,
         resume_path=args.resume,
+        resume_opt=not args.no_resume_opt,
         use_syncbn=args.sync_bn,
     )
 
@@ -683,7 +684,7 @@ def after_train_step(
     Returns:
 
     """
-    end_step = step_idx == step_end_idx
+    last_step = step_idx == step_end_idx
 
     with torch.no_grad():
         output, target, loss = tensors
@@ -696,15 +697,15 @@ def after_train_step(
         state = replace(state, step_count_global=state.step_count_global + 1)
         cfg = state.train_cfg
 
-        if services.monitor is not None and end_step or (step_idx + 1) % cfg.log_interval == 0:
+        if services.monitor is not None and last_step or (step_idx + 1) % cfg.log_interval == 0:
             global_batch_size = dev_env.world_size * output.shape[0]
             loss_avg = loss_meter.compute()
             if services.monitor is not None:
                 lr_avg = state.updater.get_average_lr()
                 services.monitor.log_step(
                     'Train',
-                    step=step_idx,
-                    step_end=step_end_idx,
+                    step_idx=step_idx,
+                    step_end_idx=step_end_idx,
                     epoch=state.epoch,
                     loss=loss_avg.item(),
                     rate=tracker.get_avg_iter_rate(global_batch_size),
@@ -712,8 +713,8 @@ def after_train_step(
                 )
 
         if services.checkpoint is not None and cfg.recovery_interval and (
-                end_step or (step_idx + 1) % cfg.recovery_interval == 0):
-            services.checkpoint.save_recovery(state.epoch, batch_idx=step_idx)
+                last_step or (step_idx + 1) % cfg.recovery_interval == 0):
+            services.checkpoint.save_recovery(state)
 
         if state.lr_scheduler is not None:
             # FIXME perform scheduler update here or via updater after_step call?
@@ -770,8 +771,8 @@ def evaluate(
                 loss_avg = losses_m.compute()
                 logger.log_step(
                     'Eval',
-                    step=step_idx,
-                    step_end=end_idx,
+                    step_idx=step_idx,
+                    step_end_idx=end_idx,
                     loss=loss_avg.item(),
                     top1=top1.item(),
                     top5=top5.item(),
diff --git a/validate.py b/validate.py
index ac1d9eb1..03a90dc0 100755
--- a/validate.py
+++ b/validate.py
@@ -173,6 +173,7 @@ def validate(args):
     with torch.no_grad():
         tracker.mark_iter()
         for step_idx, (sample, target) in enumerate(loader):
+            last_step = step_idx == num_steps - 1
             tracker.mark_iter_data_end()
 
             # compute output
@@ -197,12 +198,12 @@ def validate(args):
                 accuracy.update(output.detach(), target)
 
             tracker.mark_iter()
-            if step_idx % args.log_freq == 0:
+            if last_step or step_idx % args.log_freq == 0:
                 top1, top5 = accuracy.compute().values()
                 loss_avg = losses.compute()
                 logger.log_step(
                     phase='eval',
-                    step=step_idx,
+                    step_idx=step_idx,
                     num_steps=num_steps,
                     rate=args.batch_size / tracker.iter_time.avg,
                     loss=loss_avg.item(),