From 5f563ca4df0ed101ee0f5a7966e7a28238f4d79c Mon Sep 17 00:00:00 2001
From: datamining99
Date: Sat, 22 Aug 2020 11:31:50 +0900
Subject: [PATCH] fix save_checkpoint bug with native amp

---
 train.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index cdb5a95b..a99d5d36 100755
--- a/train.py
+++ b/train.py
@@ -544,7 +544,7 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(
                     model, optimizer, args,
-                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
+                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex and use_amp)
 
     except KeyboardInterrupt:
         pass
@@ -647,8 +647,9 @@ def train_epoch(
 
         if saver is not None and args.recovery_interval and (
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
+
             saver.save_recovery(
-                model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
+                model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex and use_amp, batch_idx=batch_idx)
 
         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
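
A minimal sketch of the backend split this guard encodes, assuming a hypothetical save_checkpoint helper with a scaler argument (not timm's actual CheckpointSaver API): apex amp keeps its loss-scale state in a global amp.state_dict(), while native AMP keeps it in a torch.cuda.amp.GradScaler, so passing use_amp=True to the saver is only valid when apex amp is actually the active backend.

# Hypothetical helper, not timm's CheckpointSaver; illustrates why the
# saver's use_amp flag must mean "apex amp is active", not "any AMP".
import torch

try:
    from apex import amp
    has_apex = True
except ImportError:
    has_apex = False


def save_checkpoint(model, optimizer, path, use_amp=False, scaler=None):
    state = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    if use_amp:
        # apex amp: amp.state_dict() holds global loss-scale state and is
        # only valid once apex is installed and amp.initialize() has run,
        # hence the `has_apex and use_amp` guard in the hunks above.
        state['amp'] = amp.state_dict()
    elif scaler is not None:
        # native AMP (torch.cuda.amp): state lives in the GradScaler instead
        state['amp_scaler'] = scaler.state_dict()
    torch.save(state, path)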