From 8e6fb861e48f9e366d1f93637c024148110b97ea Mon Sep 17 00:00:00 2001
From: Aman Arora
Date: Thu, 8 Apr 2021 03:22:29 -0400
Subject: [PATCH 01/10] Add wandb support

---
 train.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/train.py b/train.py
index 89ade4a1..9831ac76 100755
--- a/train.py
+++ b/train.py
@@ -23,6 +23,8 @@ from collections import OrderedDict
 from contextlib import suppress
 from datetime import datetime
 
+import wandb
+
 import torch
 import torch.nn as nn
 import torchvision.utils
@@ -293,7 +295,8 @@ def _parse_args():
 def main():
     setup_default_logging()
     args, args_text = _parse_args()
-
+    wandb.init(project='efficientnet_v2', config=args)
+    wandb.run.name = args.model
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
@@ -572,14 +575,14 @@ def main():
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
-
+            wandb.log(train_metrics)
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     _logger.info("Distributing BatchNorm running means and vars")
                 distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
-
+            wandb.log(eval_metrics)
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                     distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
@@ -711,7 +714,7 @@ def train_one_epoch(
     if hasattr(optimizer, 'sync_lookahead'):
         optimizer.sync_lookahead()
 
-    return OrderedDict([('loss', losses_m.avg)])
+    return OrderedDict([('train_loss', losses_m.avg)])
 
 
 def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
@@ -773,7 +776,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
                         log_name, batch_idx, last_idx, batch_time=batch_time_m,
                         loss=losses_m, top1=top1_m, top5=top5_m))
 
-    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
+    metrics = OrderedDict([('val_loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
 
     return metrics

From 00c8e0b8bdb43b7296f33e147685129e7fbbf3db Mon Sep 17 00:00:00 2001
From: Aman Arora
Date: Thu, 8 Apr 2021 03:35:59 -0400
Subject: [PATCH 02/10] Make use of wandb configurable

---
 train.py | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/train.py b/train.py
index 9831ac76..29fcc610 100755
--- a/train.py
+++ b/train.py
@@ -273,6 +273,10 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
+parser.add_argument('--use-wandb', action='store_true', default=False,
+                    help='use wandb for training and validation logs')
+parser.add_argument('--wandb-project-name', type=str, default=None,
+                    help='wandb project name to be used')
 
 
 def _parse_args():
@@ -295,7 +299,13 @@ def main():
     setup_default_logging()
     args, args_text = _parse_args()
-    wandb.init(project='efficientnet_v2', config=args)
-    wandb.run.name = args.model
+
+    if args.use_wandb:
+        if not args.wandb_project_name:
+            args.wandb_project_name = args.model
+            _logger.warning(f"Wandb project name not provided, defaulting to {args.model}")
+        wandb.init(project=args.wandb_project_name, config=args)
+
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
@@ -575,14 +584,18 @@ def main():
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
-            wandb.log(train_metrics)
+
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     _logger.info("Distributing BatchNorm running means and vars")
                 distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
-            wandb.log(eval_metrics)
+
+            if args.use_wandb:
+                wandb.log(train_metrics)
+                wandb.log(eval_metrics)
+
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                     distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')

From 624c9b6949499a60653df6285d29852f36ab70a8 Mon Sep 17 00:00:00 2001
From: Aman Arora
Date: Thu, 8 Apr 2021 03:40:22 -0400
Subject: [PATCH 03/10] log to wandb only if using using wandb

---
 timm/utils/summary.py | 4 +++-
 train.py              | 6 +-----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/timm/utils/summary.py b/timm/utils/summary.py
index a0801eaa..44a89afb 100644
--- a/timm/utils/summary.py
+++ b/timm/utils/summary.py
@@ -23,10 +23,12 @@ def get_outdir(path, *paths, inc=False):
     return outdir
 
 
-def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
+def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False):
     rowd = OrderedDict(epoch=epoch)
     rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
     rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
+    if log_wandb:
+        wandb.log(rowd)
     with open(filename, mode='a') as cf:
         dw = csv.DictWriter(cf, fieldnames=rowd.keys())
         if write_header:  # first iteration (epoch == 1 can't be used)
diff --git a/train.py b/train.py
index 29fcc610..631815ac 100755
--- a/train.py
+++ b/train.py
@@ -592,10 +592,6 @@ def main():
 
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
 
-            if args.use_wandb:
-                wandb.log(train_metrics)
-                wandb.log(eval_metrics)
-
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                     distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
@@ -609,7 +605,7 @@ def main():
 
             update_summary(
                 epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
-                write_header=best_metric is None)
+                write_header=best_metric is None, log_wandb=args.use_wandb)
 
             if saver is not None:
                 # save proper checkpoint with eval metric

From a9e5d9e5adbb674a5e4b71e0b47084259e81bb90 Mon Sep 17 00:00:00 2001
From: Aman Arora
Date: Thu, 8 Apr 2021 03:41:40 -0400
Subject: [PATCH 04/10] log loss as before

---
 train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index 631815ac..7c6f9d4b 100755
--- a/train.py
+++ b/train.py
@@ -723,7 +723,7 @@ def train_one_epoch(
     if hasattr(optimizer, 'sync_lookahead'):
         optimizer.sync_lookahead()
 
-    return OrderedDict([('train_loss', losses_m.avg)])
+    return OrderedDict([('loss', losses_m.avg)])
 
 
 def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
@@ -785,7 +785,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
                         log_name, batch_idx, last_idx, batch_time=batch_time_m,
                         loss=losses_m, top1=top1_m, top5=top5_m))
 
-    metrics = OrderedDict([('val_loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
+    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
 
     return metrics

From 3f028ebc0f8cfeb1bb70d821ff3fdcc1a2f64173 Mon Sep 17 00:00:00 2001
From: Aman Arora
Date: Thu, 8 Apr 2021 03:48:51 -0400
Subject: [PATCH 05/10] import wandb in summary.py

---
 timm/utils/summary.py | 1 +
 train.py              | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/timm/utils/summary.py b/timm/utils/summary.py
index 44a89afb..10e317c5 100644
--- a/timm/utils/summary.py
+++ b/timm/utils/summary.py
@@ -4,6 +4,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 import csv
 import os
+import wandb
 from collections import OrderedDict
 
 
diff --git a/train.py b/train.py
index 7c6f9d4b..172008a2 100755
--- a/train.py
+++ b/train.py
@@ -302,8 +302,8 @@ def main():
 
     if args.use_wandb:
         if not args.wandb_project_name:
-            args.wandb_project_name = args.model
-            _logger.warning(f"Wandb project name not provided, defaulting to {args.model}")
+            args.wandb_project_name = f'timm_{args.model}'
+            _logger.warning(f"Wandb project name not provided, defaulting to timm_{args.model}")
         wandb.init(project=args.wandb_project_name, config=args)
 
     args.prefetcher = not args.no_prefetcher

From 8db8ff346fcc405c8bdaf62ac8f6574253402ed4 Mon Sep 17 00:00:00 2001
From: Aman Arora
Date: Thu, 8 Apr 2021 03:52:14 -0400
Subject: [PATCH 06/10] add wandb to requirements.txt

---
 requirements.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements.txt b/requirements.txt
index 2d29a27c..7fa06ee5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 torch>=1.4.0
 torchvision>=0.5.0
 pyyaml
+wandb
\ No newline at end of file

From f8bb13f64077f6041b36ab4fdb8ca12a7d5b63b2 Mon Sep 17 00:00:00 2001
From: Aman Arora
Date: Sat, 10 Apr 2021 00:44:05 -0400
Subject: [PATCH 07/10] Default project name to None

---
 train.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/train.py b/train.py
index 172008a2..6058a05b 100755
--- a/train.py
+++ b/train.py
@@ -301,9 +301,6 @@ def main():
     args, args_text = _parse_args()
 
     if args.use_wandb:
-        if not args.wandb_project_name:
-            args.wandb_project_name = f'timm_{args.model}'
-            _logger.warning(f"Wandb project name not provided, defaulting to timm_{args.model}")
         wandb.init(project=args.wandb_project_name, config=args)
 
     args.prefetcher = not args.no_prefetcher

From f13f7508a9d68d26853963ad23d9692172d2a467 Mon Sep 17 00:00:00 2001
From: Aman Arora
Date: Sat, 10 Apr 2021 00:50:52 -0400
Subject: [PATCH 08/10] Keep changes to minimal and use args.experiment as wandb project name if it exists

---
 train.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/train.py b/train.py
index 6058a05b..2483531b 100755
--- a/train.py
+++ b/train.py
@@ -273,10 +273,8 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
-parser.add_argument('--use-wandb', action='store_true', default=False,
+parser.add_argument('--log-wandb', action='store_true', default=False,
                     help='use wandb for training and validation logs')
-parser.add_argument('--wandb-project-name', type=str, default=None,
-                    help='wandb project name to be used')
 
 
 def _parse_args():
@@ -300,8 +298,8 @@ def main():
     setup_default_logging()
     args, args_text = _parse_args()
 
-    if args.use_wandb:
-        wandb.init(project=args.wandb_project_name, config=args)
+    if args.log_wandb:
+        wandb.init(project=args.experiment, config=args)
 
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
@@ -602,7 +600,7 @@ def main():
 
             update_summary(
                 epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
-                write_header=best_metric is None, log_wandb=args.use_wandb)
+                write_header=best_metric is None, log_wandb=args.log_wandb)
 
             if saver is not None:
                 # save proper checkpoint with eval metric

From f54897cc0ba741eea2e20bd54f9f0d1480a1711f Mon Sep 17 00:00:00 2001
From: Aman Arora
Date: Sat, 10 Apr 2021 01:27:23 -0400
Subject: [PATCH 09/10] make wandb not required but rather optional as huggingface_hub

---
 requirements.txt      |  3 +--
 timm/utils/summary.py |  8 ++++----
 train.py              | 20 ++++++++++++++------
 3 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 7fa06ee5..251cb4a3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,3 @@
 torch>=1.4.0
 torchvision>=0.5.0
-pyyaml
-wandb
\ No newline at end of file
+pyyaml
\ No newline at end of file
diff --git a/timm/utils/summary.py b/timm/utils/summary.py
index 10e317c5..f6625835 100644
--- a/timm/utils/summary.py
+++ b/timm/utils/summary.py
@@ -4,9 +4,11 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 import csv
 import os
-import wandb
 from collections import OrderedDict
-
+try:
+    import wandb
+except ImportError:
+    pass
 
 def get_outdir(path, *paths, inc=False):
     outdir = os.path.join(path, *paths)
@@ -28,8 +30,6 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa
     rowd = OrderedDict(epoch=epoch)
     rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
     rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
-    if log_wandb:
-        wandb.log(rowd)
     with open(filename, mode='a') as cf:
         dw = csv.DictWriter(cf, fieldnames=rowd.keys())
         if write_header:  # first iteration (epoch == 1 can't be used)
diff --git a/train.py b/train.py
index 2483531b..02ea20ef 100755
--- a/train.py
+++ b/train.py
@@ -23,8 +23,6 @@ from collections import OrderedDict
 from contextlib import suppress
 from datetime import datetime
 
-import wandb
-
 import torch
 import torch.nn as nn
 import torchvision.utils
@@ -54,6 +52,12 @@ try:
 except AttributeError:
     pass
 
+try:
+    import wandb
+    has_wandb = True
+except ModuleNotFoundError:
+    has_wandb = False
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')
 
@@ -274,7 +278,7 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
 parser.add_argument('--log-wandb', action='store_true', default=False,
-                    help='use wandb for training and validation logs')
+                    help='log training and validation metrics to wandb')
 
 
 def _parse_args():
@@ -299,8 +303,12 @@ def main():
     args, args_text = _parse_args()
 
     if args.log_wandb:
-        wandb.init(project=args.experiment, config=args)
-
+        if has_wandb:
+            wandb.init(project=args.experiment, config=args)
+        else:
+            _logger.warning("You've requested to log metrics to wandb but package not found. "
+                            "Metrics not being logged to wandb, try `pip install wandb`")
+
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
@@ -600,7 +608,7 @@ def main():
 
             update_summary(
                 epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
-                write_header=best_metric is None, log_wandb=args.log_wandb)
+                write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
 
             if saver is not None:
                 # save proper checkpoint with eval metric

From 5772c55c5781b46cde843cbfee17ec77fe3ec53d Mon Sep 17 00:00:00 2001
From: Aman Arora
Date: Sat, 10 Apr 2021 01:34:20 -0400
Subject: [PATCH 10/10] Make wandb optional

---
 requirements.txt      | 2 +-
 timm/utils/summary.py | 2 ++
 train.py              | 2 +-
 3 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 251cb4a3..2d29a27c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
 torch>=1.4.0
 torchvision>=0.5.0
-pyyaml
\ No newline at end of file
+pyyaml
diff --git a/timm/utils/summary.py b/timm/utils/summary.py
index f6625835..9f5af9a0 100644
--- a/timm/utils/summary.py
+++ b/timm/utils/summary.py
@@ -30,6 +30,8 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa
     rowd = OrderedDict(epoch=epoch)
     rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
     rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
+    if log_wandb:
+        wandb.log(rowd)
     with open(filename, mode='a') as cf:
         dw = csv.DictWriter(cf, fieldnames=rowd.keys())
         if write_header:  # first iteration (epoch == 1 can't be used)
diff --git a/train.py b/train.py
index 02ea20ef..bf17364e 100755
--- a/train.py
+++ b/train.py
@@ -55,7 +55,7 @@ except AttributeError:
 try:
     import wandb
     has_wandb = True
-except ModuleNotFoundError:
+except ImportError:
     has_wandb = False
 
 torch.backends.cudnn.benchmark = True