diff --git a/train.py b/train.py
index 3acb9909..7ac84a81 100755
--- a/train.py
+++ b/train.py
@@ -120,7 +120,9 @@ parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
 parser.add_argument('--lr', type=float, default=None, metavar='LR',
                     help='learning rate (default: None => --lr-base')
 parser.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
-                    help='base learning rate: lr = lr_base * global_batch_size / 256')
+                    help='base learning rate: lr = lr_base * global_batch_size / base_size')
+parser.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
+                    help='base learning rate batch size (divisor, default: 256).')
 parser.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
                     help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
 parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
@@ -445,7 +447,7 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
 
     if args.lr is None:
         global_batch_size = args.batch_size * dev_env.world_size
-        batch_ratio = global_batch_size / 256
+        batch_ratio = global_batch_size / args.lr_base_size
         if not args.lr_base_scale:
             on = args.opt.lower()
             args.lr_base_scale = 'sqrt' if any([o in on for o in ('adam', 'lamb', 'adabelief')]) else 'linear'
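
For reference, a minimal standalone sketch of the LR derivation this patch yields when `--lr` is unset. The `resolve_lr` helper and the example numbers are illustrative, not part of the patch, and the final sqrt/linear application step is assumed from the `--lr-base-scale` help text rather than shown in the hunk:

```python
def resolve_lr(lr_base: float, batch_size: int, world_size: int,
               lr_base_size: int = 256, opt: str = 'sgd',
               lr_base_scale: str = '') -> float:
    """Derive lr from lr_base, mirroring the setup_train_task() logic.

    lr = lr_base * (global_batch_size / lr_base_size), with sqrt rather
    than linear scaling of the ratio for adaptive optimizers.
    """
    global_batch_size = batch_size * world_size
    batch_ratio = global_batch_size / lr_base_size
    if not lr_base_scale:
        # Pick sqrt scaling for adaptive optimizers, linear otherwise,
        # matching the substring check in the second hunk.
        on = opt.lower()
        lr_base_scale = 'sqrt' if any(o in on for o in ('adam', 'lamb', 'adabelief')) else 'linear'
    if lr_base_scale == 'sqrt':
        # Assumed application of the chosen scale; not shown in the hunk.
        batch_ratio = batch_ratio ** 0.5
    return lr_base * batch_ratio

# e.g. 8 processes x batch 128 = 1024 global batch, lr_base=0.1:
#   SGD   (linear): 0.1 * 1024/256        = 0.4
#   AdamW (sqrt):   0.1 * (1024/256)**0.5 = 0.2
```

With the default `--lr-base-size 256` the computed LR is unchanged from before the patch; setting it to the batch size a recipe was tuned at keeps the effective LR stable as the global batch size varies.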