From 624c9b6949499a60653df6285d29852f36ab70a8 Mon Sep 17 00:00:00 2001 From: Aman Arora Date: Thu, 8 Apr 2021 03:40:22 -0400 Subject: [PATCH] log to wandb only if using using wandb --- timm/utils/summary.py | 4 +++- train.py | 6 +----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/timm/utils/summary.py b/timm/utils/summary.py index a0801eaa..44a89afb 100644 --- a/timm/utils/summary.py +++ b/timm/utils/summary.py @@ -23,10 +23,12 @@ def get_outdir(path, *paths, inc=False): return outdir -def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False): +def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False): rowd = OrderedDict(epoch=epoch) rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) + if log_wandb: + wandb.log(rowd) with open(filename, mode='a') as cf: dw = csv.DictWriter(cf, fieldnames=rowd.keys()) if write_header: # first iteration (epoch == 1 can't be used) diff --git a/train.py b/train.py index 29fcc610..631815ac 100755 --- a/train.py +++ b/train.py @@ -592,10 +592,6 @@ def main(): eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) - if args.use_wandb: - wandb.log(train_metrics) - wandb.log(eval_metrics) - if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') @@ -609,7 +605,7 @@ def main(): update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), - write_header=best_metric is None) + write_header=best_metric is None, log_wandb=args.use_wandb) if saver is not None: # save proper checkpoint with eval metric