diff --git a/train.py b/train.py index 4264a164..f1c1581e 100755 --- a/train.py +++ b/train.py @@ -561,7 +561,7 @@ def main(): best_epoch = None saver = None output_dir = None - if args.local_rank == 0: + if args.rank == 0: if args.experiment: exp_name = args.experiment else: