pytorch-image-models/timm/utils/summary.py

51 lines
1.3 KiB

""" Summary utilities
Hacked together by / Copyright 2020 Ross Wightman
"""
import csv
import os
from collections import OrderedDict
try:
import wandb
except ImportError:
pass
def get_outdir(path, *paths, inc=False):
outdir = os.path.join(path, *paths)
if not os.path.exists(outdir):
os.makedirs(outdir)
elif inc:
count = 1
outdir_inc = outdir + '-' + str(count)
while os.path.exists(outdir_inc):
count = count + 1
outdir_inc = outdir + '-' + str(count)
assert count < 100
outdir = outdir_inc
os.makedirs(outdir)
return outdir
def update_summary(
epoch,
train_metrics,
eval_metrics,
filename,
lr=None,
write_header=False,
log_wandb=False,
):
rowd = OrderedDict(epoch=epoch)
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
if lr is not None:
rowd['lr'] = lr
if log_wandb:
wandb.log(rowd)
with open(filename, mode='a') as cf:
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
if write_header: # first iteration (epoch == 1 can't be used)
dw.writeheader()
dw.writerow(rowd)