diff --git a/train.py b/train.py index 9831ac76..29fcc610 100755 --- a/train.py +++ b/train.py @@ -273,6 +273,10 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa help='use the multi-epochs-loader to save time at the beginning of every epoch') parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='convert model torchscript for inference') +parser.add_argument('--use-wandb', action='store_true', default=False, + help='use wandb for training and validation logs') +parser.add_argument('--wandb-project-name', type=str, default=None, + help='wandb project name to be used') def _parse_args(): @@ -295,8 +299,13 @@ def _parse_args(): def main(): setup_default_logging() args, args_text = _parse_args() - wandb.init(project='efficientnet_v2', config=args) - wandb.run.name = args.model + + if args.use_wandb: + if not args.wandb_project_name: + args.wandb_project_name = args.model + _logger.warning(f"Wandb project name not provided, defaulting to {args.model}") + wandb.init(project=args.wandb_project_name, config=args) + args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: @@ -575,14 +584,18 @@ def main(): epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) - wandb.log(train_metrics) + if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: _logger.info("Distributing BatchNorm running means and vars") distribute_bn(model, args.world_size, args.dist_bn == 'reduce') eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) - wandb.log(eval_metrics) + + if args.use_wandb: + wandb.log(train_metrics) + wandb.log(eval_metrics) + if model_ema is not None and not args.model_ema_force_cpu: if args.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')