You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
1.3 KiB
51 lines
1.3 KiB
""" Summary utilities
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
"""
|
|
import csv
|
|
import os
|
|
from collections import OrderedDict
|
|
try:
|
|
import wandb
|
|
except ImportError:
|
|
pass
|
|
|
|
|
|
def get_outdir(path, *paths, inc=False):
    """Join ``path`` with ``paths`` and make sure the resulting directory exists.

    Args:
        path: Base directory.
        *paths: Extra components joined onto ``path`` with ``os.path.join``.
        inc: When True and the joined path already exists, create and return a
            fresh sibling ``<outdir>-<n>`` instead, where ``n`` is the smallest
            value in 1..99 whose path does not exist yet.

    Returns:
        The directory path that was found or created: the joined path itself,
        or the suffixed sibling when ``inc`` is True and the joined path was
        already taken.

    Raises:
        FileExistsError: ``inc`` is True and ``<outdir>-1`` .. ``<outdir>-99``
            all exist already.
    """
    outdir = os.path.join(path, *paths)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    elif inc:
        for count in range(1, 100):
            outdir_inc = outdir + '-' + str(count)
            if not os.path.exists(outdir_inc):
                os.makedirs(outdir_inc)
                return outdir_inc
        # Explicit raise rather than `assert`: assertions are stripped under
        # `python -O`, which would silently remove this cap.
        raise FileExistsError(
            'No free incremented output dir for {!r}: {!r} .. {!r} all exist'.format(
                outdir, outdir + '-1', outdir + '-99'))
    return outdir
|
|
|
|
|
|
def update_summary(
        epoch,
        train_metrics,
        eval_metrics,
        filename,
        lr=None,
        write_header=False,
        log_wandb=False,
):
    """Append one epoch's metrics as a single row of a CSV summary file.

    The row's columns are, in order: ``epoch``; every ``train_metrics`` entry
    with its key prefixed by ``train_``; every ``eval_metrics`` entry prefixed
    by ``eval_``; and ``lr`` when it is not None. The column list is rebuilt
    from this call's arguments, so rows only line up with a previously written
    header if the caller passes the same keys in the same order each time.

    Args:
        epoch: Value written to the ``epoch`` column.
        train_metrics: Mapping of training metric name -> value.
        eval_metrics: Mapping of evaluation metric name -> value.
        filename: CSV file to append to (opened in append mode, so it is
            created if missing).
        lr: Optional learning rate; adds an ``lr`` column when given.
        write_header: Write the CSV header row before the data row.
        log_wandb: Also pass the row to ``wandb.log``. This needs the optional
            module-level ``wandb`` import to have succeeded; otherwise the
            name is unbound and this raises ``NameError``.
    """
    rowd = OrderedDict(epoch=epoch)
    rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
    if lr is not None:
        rowd['lr'] = lr
    if log_wandb:
        wandb.log(rowd)
    # newline='' is required by the csv module: the writer emits its own
    # '\r\n' terminators, and text-mode newline translation would otherwise
    # turn them into '\r\r\n' (blank rows) on platforms such as Windows.
    with open(filename, mode='a', newline='') as cf:
        dw = csv.DictWriter(cf, fieldnames=rowd.keys())
        if write_header:  # first iteration (epoch == 1 can't be used)
            dw.writeheader()
        dw.writerow(rowd)
|