diff --git a/requirements.txt b/requirements.txt index 7fa06ee5..251cb4a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ torch>=1.4.0 torchvision>=0.5.0 -pyyaml -wandb \ No newline at end of file +pyyaml \ No newline at end of file diff --git a/timm/utils/summary.py b/timm/utils/summary.py index 10e317c5..f6625835 100644 --- a/timm/utils/summary.py +++ b/timm/utils/summary.py @@ -4,9 +4,11 @@ Hacked together by / Copyright 2020 Ross Wightman """ import csv import os -import wandb from collections import OrderedDict - +try: + import wandb +except ImportError: + pass def get_outdir(path, *paths, inc=False): outdir = os.path.join(path, *paths) @@ -28,8 +30,6 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa rowd = OrderedDict(epoch=epoch) rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) - if log_wandb: - wandb.log(rowd) with open(filename, mode='a') as cf: dw = csv.DictWriter(cf, fieldnames=rowd.keys()) if write_header: # first iteration (epoch == 1 can't be used) diff --git a/train.py b/train.py index 2483531b..02ea20ef 100755 --- a/train.py +++ b/train.py @@ -23,8 +23,6 @@ from collections import OrderedDict from contextlib import suppress from datetime import datetime -import wandb - import torch import torch.nn as nn import torchvision.utils @@ -54,6 +52,12 @@ try: except AttributeError: pass +try: + import wandb + has_wandb = True +except ModuleNotFoundError: + has_wandb = False + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('train') @@ -274,7 +278,7 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='convert model torchscript for inference') parser.add_argument('--log-wandb', action='store_true', default=False, - help='use wandb for training and validation logs') + help='log training and validation metrics to wandb') def _parse_args(): @@ -299,8 +303,12 @@ def main(): args, args_text = _parse_args() if args.log_wandb: - wandb.init(project=args.experiment, config=args) - + if has_wandb: + wandb.init(project=args.experiment, config=args) + else: + _logger.warning("You've requested to log metrics to wandb but package not found. " + "Metrics not being logged to wandb, try `pip install wandb`") + args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: @@ -600,7 +608,7 @@ def main(): update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), - write_header=best_metric is None, log_wandb=args.log_wandb) + write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) if saver is not None: # save proper checkpoint with eval metric