From a336e5bff371484afb822ade1d37afd476f41cdf Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 8 Feb 2019 20:56:24 -0800 Subject: [PATCH] Minor updates --- scheduler/__init__.py | 1 + train.py | 5 ++--- utils.py | 13 +++++++++++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/scheduler/__init__.py b/scheduler/__init__.py index 73f9c78d..8242163f 100644 --- a/scheduler/__init__.py +++ b/scheduler/__init__.py @@ -1,3 +1,4 @@ from .cosine_lr import CosineLRScheduler from .plateau_lr import PlateauLRScheduler from .step_lr import StepLRScheduler +from .tanh_lr import TanhLRScheduler \ No newline at end of file diff --git a/train.py b/train.py index 63007370..2adc54b3 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,4 @@ import argparse -import csv -import os import time from collections import OrderedDict from datetime import datetime @@ -218,7 +216,8 @@ def main(): lr_scheduler.step(epoch, eval_metrics['eval_loss']) update_summary( - epoch, train_metrics, eval_metrics, output_dir, write_header=best_loss is None) + epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), + write_header=best_loss is None) # save proper checkpoint with eval metric best_loss = saver.save_checkpoint({ diff --git a/utils.py b/utils.py index ad5f7780..4604a258 100644 --- a/utils.py +++ b/utils.py @@ -5,6 +5,8 @@ import numpy as np import os import shutil import glob +import csv +from collections import OrderedDict class CheckpointSaver: @@ -137,3 +139,14 @@ def get_outdir(path, *paths, inc=False): outdir = outdir_inc os.makedirs(outdir) return outdir + + +def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False): + rowd = OrderedDict(epoch=epoch) + rowd.update(train_metrics) + rowd.update(eval_metrics) + with open(filename, mode='a') as cf: + dw = csv.DictWriter(cf, fieldnames=rowd.keys()) + if write_header: # first iteration (epoch == 1 can't be used) + dw.writeheader() + dw.writerow(rowd)