|
|
|
@ -452,8 +452,9 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
|
|
|
|
|
if args.lr_base_scale == 'sqrt':
|
|
|
|
|
batch_ratio = batch_ratio ** 0.5
|
|
|
|
|
args.lr = args.lr_base * batch_ratio
|
|
|
|
|
_logger.info(f'Calculated learning rate ({args.lr}) from base learning rate ({args.lr_base}) '
|
|
|
|
|
f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
|
|
|
|
|
if dev_env.primary:
|
|
|
|
|
_logger.info(f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
|
|
|
|
|
f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
|
|
|
|
|
|
|
|
|
|
train_state = setup_model_and_optimizer(
|
|
|
|
|
dev_env=dev_env,
|
|
|
|
|