|
|
|
@ -544,7 +544,7 @@ def main():
|
|
|
|
|
save_metric = eval_metrics[eval_metric]
|
|
|
|
|
best_metric, best_epoch = saver.save_checkpoint(
|
|
|
|
|
model, optimizer, args,
|
|
|
|
|
epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=use_amp)
|
|
|
|
|
epoch=epoch, model_ema=model_ema, metric=save_metric, use_amp=has_apex&use_amp)
|
|
|
|
|
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
|
pass
|
|
|
|
@ -647,8 +647,9 @@ def train_epoch(
|
|
|
|
|
|
|
|
|
|
if saver is not None and args.recovery_interval and (
|
|
|
|
|
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
|
|
|
|
|
|
|
|
|
saver.save_recovery(
|
|
|
|
|
model, optimizer, args, epoch, model_ema=model_ema, use_amp=use_amp, batch_idx=batch_idx)
|
|
|
|
|
model, optimizer, args, epoch, model_ema=model_ema, use_amp=has_apex&use_amp, batch_idx=batch_idx)
|
|
|
|
|
|
|
|
|
|
if lr_scheduler is not None:
|
|
|
|
|
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
|
|
|
|