make wandb not required but rather optional as huggingface_hub

pull/550/head
Aman Arora 4 years ago
parent f13f7508a9
commit f54897cc0b

@ -1,4 +1,3 @@
torch>=1.4.0 torch>=1.4.0
torchvision>=0.5.0 torchvision>=0.5.0
pyyaml pyyaml
wandb

@ -4,9 +4,11 @@ Hacked together by / Copyright 2020 Ross Wightman
""" """
import csv import csv
import os import os
import wandb
from collections import OrderedDict from collections import OrderedDict
try:
import wandb
except ImportError:
pass
def get_outdir(path, *paths, inc=False): def get_outdir(path, *paths, inc=False):
outdir = os.path.join(path, *paths) outdir = os.path.join(path, *paths)
@ -28,8 +30,6 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa
rowd = OrderedDict(epoch=epoch) rowd = OrderedDict(epoch=epoch)
rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
if log_wandb:
wandb.log(rowd)
with open(filename, mode='a') as cf: with open(filename, mode='a') as cf:
dw = csv.DictWriter(cf, fieldnames=rowd.keys()) dw = csv.DictWriter(cf, fieldnames=rowd.keys())
if write_header: # first iteration (epoch == 1 can't be used) if write_header: # first iteration (epoch == 1 can't be used)

@ -23,8 +23,6 @@ from collections import OrderedDict
from contextlib import suppress from contextlib import suppress
from datetime import datetime from datetime import datetime
import wandb
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision.utils import torchvision.utils
@ -54,6 +52,12 @@ try:
except AttributeError: except AttributeError:
pass pass
try:
import wandb
has_wandb = True
except ModuleNotFoundError:
has_wandb = False
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train') _logger = logging.getLogger('train')
@ -274,7 +278,7 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
parser.add_argument('--torchscript', dest='torchscript', action='store_true', parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference') help='convert model torchscript for inference')
parser.add_argument('--log-wandb', action='store_true', default=False, parser.add_argument('--log-wandb', action='store_true', default=False,
help='use wandb for training and validation logs') help='log training and validation metrics to wandb')
def _parse_args(): def _parse_args():
@ -299,7 +303,11 @@ def main():
args, args_text = _parse_args() args, args_text = _parse_args()
if args.log_wandb: if args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args) wandb.init(project=args.experiment, config=args)
else:
_logger.warning("You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
args.prefetcher = not args.no_prefetcher args.prefetcher = not args.no_prefetcher
args.distributed = False args.distributed = False
@ -600,7 +608,7 @@ def main():
update_summary( update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None, log_wandb=args.log_wandb) write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
if saver is not None: if saver is not None:
# save proper checkpoint with eval metric # save proper checkpoint with eval metric

Loading…
Cancel
Save