make wandb not required but rather optional as huggingface_hub

pull/550/head
Aman Arora 4 years ago
parent f13f7508a9
commit f54897cc0b

@ -1,4 +1,3 @@
torch>=1.4.0
torchvision>=0.5.0
pyyaml
wandb
pyyaml

@ -4,9 +4,11 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
import csv
import os
import wandb
from collections import OrderedDict
try:
import wandb
except ImportError:
pass
def get_outdir(path, *paths, inc=False):
outdir = os.path.join(path, *paths)
@ -28,8 +30,6 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa
rowd = OrderedDict(epoch=epoch)
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
if log_wandb:
wandb.log(rowd)
with open(filename, mode='a') as cf:
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
if write_header: # first iteration (epoch == 1 can't be used)

@ -23,8 +23,6 @@ from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
import wandb
import torch
import torch.nn as nn
import torchvision.utils
@ -54,6 +52,12 @@ try:
except AttributeError:
pass
try:
import wandb
has_wandb = True
except ModuleNotFoundError:
has_wandb = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
@ -274,7 +278,7 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--log-wandb', action='store_true', default=False,
help='use wandb for training and validation logs')
help='log training and validation metrics to wandb')
def _parse_args():
@ -299,8 +303,12 @@ def main():
args, args_text = _parse_args()
if args.log_wandb:
wandb.init(project=args.experiment, config=args)
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning("You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
args.prefetcher = not args.no_prefetcher
args.distributed = False
if 'WORLD_SIZE' in os.environ:
@ -600,7 +608,7 @@ def main():
update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None, log_wandb=args.log_wandb)
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
if saver is not None:
# save proper checkpoint with eval metric

Loading…
Cancel
Save