|
|
|
@ -30,6 +30,8 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa
|
|
|
|
|
rowd = OrderedDict(epoch=epoch)
|
|
|
|
|
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
|
|
|
|
|
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
|
|
|
|
|
if log_wandb:
|
|
|
|
|
wandb.log(rowd)
|
|
|
|
|
with open(filename, mode='a') as cf:
|
|
|
|
|
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
|
|
|
|
|
if write_header: # first iteration (epoch == 1 can't be used)
|
|
|
|
|