@@ -560,7 +560,7 @@ def main():
     best_metric = None
     best_epoch = None
     saver = None
-    output_dir = ''
+    output_dir = None
     if args.local_rank == 0:
         if args.experiment:
             exp_name = args.experiment
@@ -606,9 +606,10 @@ def main():
                 # step LR for next epoch
                 lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

-            update_summary(
-                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
-                write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
+            if output_dir is not None:
+                update_summary(
+                    epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
+                    write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)

             if saver is not None:
                 # save proper checkpoint with eval metric
@@ -623,7 +624,7 @@ def main():

 def train_one_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
+        lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
         loss_scaler=None, model_ema=None, mixup_fn=None):

     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
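Taken together, the hunks above swap the `output_dir` sentinel from `''` to `None` and gate the summary write on it; with the old empty-string default, `os.path.join('', 'summary.csv')` resolves to a bare `'summary.csv'` in whatever the current working directory happens to be. A minimal sketch of the guarded pattern, using a stand-in `update_summary` and made-up metric values (the real helper and the wandb plumbing live in timm's utils, not here):

import csv
import os
from collections import OrderedDict


def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
    # Stand-in for timm's summary writer: append one merged row of metrics per epoch.
    rowd = OrderedDict(epoch=epoch)
    rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
    with open(filename, mode='a', newline='') as cf:
        dw = csv.DictWriter(cf, fieldnames=rowd.keys())
        if write_header:
            dw.writeheader()
        dw.writerow(rowd)


output_dir = None                             # e.g. a non-zero rank, or no experiment dir configured
best_metric = None
train_metrics = {'loss': 1.23}                # placeholder values for illustration
eval_metrics = {'loss': 1.10, 'top1': 71.5}

# With output_dir = None the write is skipped entirely, instead of joining an
# empty string into a path and dropping summary.csv next to wherever training runs.
if output_dir is not None:
    update_summary(
        0, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
        write_header=best_metric is None)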