diff --git a/train.py b/train.py
index 047a8256..444be066 100755
--- a/train.py
+++ b/train.py
@@ -347,13 +347,6 @@ def main():
     utils.setup_default_logging()
     args, args_text = _parse_args()
 
-    if args.log_wandb:
-        if has_wandb:
-            wandb.init(project=args.experiment, config=args)
-        else:
-            _logger.warning("You've requested to log metrics to wandb but package not found. "
-                            "Metrics not being logged to wandb, try `pip install wandb`")
-
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
@@ -373,6 +366,13 @@ def main():
         _logger.info('Training with a single process on 1 GPUs.')
     assert args.rank >= 0
 
+    if args.rank == 0 and args.log_wandb:
+        if has_wandb:
+            wandb.init(project=args.experiment, config=args)
+        else:
+            _logger.warning("You've requested to log metrics to wandb but package not found. "
+                            "Metrics not being logged to wandb, try `pip install wandb`")
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     if args.amp: