diff --git a/train.py b/train.py index 0405f941..3acb9909 100755 --- a/train.py +++ b/train.py @@ -452,8 +452,9 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool): if args.lr_base_scale == 'sqrt': batch_ratio = batch_ratio ** 0.5 args.lr = args.lr_base * batch_ratio - _logger.info(f'Calculated learning rate ({args.lr}) from base learning rate ({args.lr_base}) ' - f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.') + if dev_env.primary: + _logger.info(f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) ' + f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.') train_state = setup_model_and_optimizer( dev_env=dev_env,