Minor updates

pull/1/head
Ross Wightman 5 years ago
parent cf0c280e1b
commit a336e5bff3

@@ -1,3 +1,4 @@
from .cosine_lr import CosineLRScheduler
from .plateau_lr import PlateauLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler

@@ -1,6 +1,4 @@
import argparse
import csv
import os
import time
from collections import OrderedDict
from datetime import datetime
@@ -218,7 +216,8 @@ def main():
             lr_scheduler.step(epoch, eval_metrics['eval_loss'])
         update_summary(
-            epoch, train_metrics, eval_metrics, output_dir, write_header=best_loss is None)
+            epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
+            write_header=best_loss is None)
         # save proper checkpoint with eval metric
         best_loss = saver.save_checkpoint({
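For orientation, this hunk sits at the tail of the per-epoch loop in main(): step the scheduler on the eval loss, append a row to the summary CSV, then save a checkpoint keyed on the same metric. Below is a rough, self-contained sketch of that flow using stock PyTorch stand-ins rather than this repo's classes: ReduceLROnPlateau stands in for the metric-driven scheduler (its step() takes only the metric, not step(epoch, metric) as above), and torch.save stands in for CheckpointSaver.save_checkpoint.

import os
import torch
import torch.nn as nn

# Toy model/optimizer so the sketch runs on its own.
model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Stand-in for the repo's metric-driven scheduler.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

output_dir = './output/sketch'
os.makedirs(output_dir, exist_ok=True)

best_loss = None
for epoch in range(1, 4):
    eval_loss = 1.0 / epoch           # placeholder for validate() output
    lr_scheduler.step(eval_loss)      # drive the schedule from the eval metric
    if best_loss is None or eval_loss < best_loss:
        best_loss = eval_loss
        # Stand-in for saver.save_checkpoint(...): keep the best checkpoint by eval loss.
        torch.save({'epoch': epoch, 'state_dict': model.state_dict()},
                   os.path.join(output_dir, 'checkpoint.pth.tar'))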

@@ -5,6 +5,8 @@ import numpy as np
 import os
 import shutil
 import glob
+import csv
+from collections import OrderedDict


 class CheckpointSaver:
@@ -137,3 +139,14 @@ def get_outdir(path, *paths, inc=False):
         outdir = outdir_inc
         os.makedirs(outdir)
     return outdir
+
+
+def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
+    rowd = OrderedDict(epoch=epoch)
+    rowd.update(train_metrics)
+    rowd.update(eval_metrics)
+    with open(filename, mode='a') as cf:
+        dw = csv.DictWriter(cf, fieldnames=rowd.keys())
+        if write_header:  # first iteration (epoch == 1 can't be used)
+            dw.writeheader()
+        dw.writerow(rowd)
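A small usage sketch for the new helper, mirroring the train.py call above. The import path and the metric values are assumptions for illustration (it presumes update_summary is importable from utils.py as patched here); the os.path.join(output_dir, 'summary.csv') target and the write_header=best_loss is None pattern are taken from this commit.

import os
from collections import OrderedDict

from utils import update_summary  # assumed import path for the helper added above

output_dir = './output/train'
os.makedirs(output_dir, exist_ok=True)

best_loss = None
for epoch in range(1, 4):
    # Placeholder metrics; in train.py these come from train_epoch() and validate().
    train_metrics = OrderedDict(train_loss=1.0 / epoch)
    eval_metrics = OrderedDict(eval_loss=1.2 / epoch)

    update_summary(
        epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
        write_header=best_loss is None)  # header is written only on the first row

    best_loss = eval_metrics['eval_loss']

# summary.csv now has a header plus one row per epoch: epoch, train_loss, eval_loss.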
