fix save_checkpoint bug with native amp

pull/228/head
datamining99 4 years ago
parent d98967ed5d
commit 5f563ca4df

@@ -544,7 +544,7 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(
                     model, optimizer, args,
-                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
+                    epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex&use_amp)

     except KeyboardInterrupt:
         pass
@@ -647,8 +647,9 @@ def train_epoch(
         if saver is not None and args.recovery_interval and (
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
             saver.save_recovery(
-                model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
+                model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex&use_amp, batch_idx=batch_idx)

         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
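The change masks the saver's use_amp flag with has_apex so that Apex-specific checkpoint state is only serialized when Apex is actually the active AMP backend. With native torch.cuda.amp, use_amp is truthy but Apex's amp module was never initialized, so trying to save its state at checkpoint time breaks. Below is a minimal sketch of the failure mode; the save_checkpoint body is a simplified stand-in for timm's CheckpointSaver, not its exact implementation:

    import torch

    try:
        from apex import amp  # Apex AMP; optional dependency
        has_apex = True
    except ImportError:
        amp = None
        has_apex = False

    def save_checkpoint(model, optimizer, path, use_amp=False):
        # Simplified stand-in for the saver's apex-only branch (an
        # assumption for illustration, not timm's exact code).
        state = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        if use_amp and amp is not None:
            # amp.state_dict() assumes amp.initialize() was called; under
            # native torch.cuda.amp this state is never set up, so saving
            # it is wrong even when Apex happens to be installed.
            state['amp'] = amp.state_dict()
        torch.save(state, path)

Passing use_amp=has_apex & use_amp keeps the flag False whenever Apex is unavailable, so training with native AMP no longer trips the Apex-only branch; assuming both operands are booleans here, the bitwise & is equivalent to a logical and.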
