@@ -120,7 +120,9 @@ parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
 parser.add_argument('--lr', type=float, default=None, metavar='LR',
                     help='learning rate (default: None => --lr-base)')
 parser.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
-                    help='base learning rate: lr = lr_base * global_batch_size / 256')
+                    help='base learning rate: lr = lr_base * global_batch_size / base_size')
+parser.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
+                    help='base learning rate batch size (divisor, default: 256).')
 parser.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
                     help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
 parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
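
Taken together, these flags let the script derive the learning rate instead of requiring --lr explicitly: --lr-base is the reference rate at a reference global batch size, and the new --lr-base-size makes the previously hard-coded 256 divisor configurable. Under linear scaling, the defaults (--lr-base 0.1, --lr-base-size 256) with a global batch of 512 resolve to lr = 0.1 * 512 / 256 = 0.2; see the sketch after the next hunk for the full resolution, including the sqrt option.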
@@ -445,7 +447,7 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
 
     if args.lr is None:
         global_batch_size = args.batch_size * dev_env.world_size
-        batch_ratio = global_batch_size / 256
+        batch_ratio = global_batch_size / args.lr_base_size
         if not args.lr_base_scale:
             on = args.opt.lower()
             args.lr_base_scale = 'sqrt' if any([o in on for o in ('adam', 'lamb', 'adabelief')]) else 'linear'
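
To make the resolution concrete, here is a minimal standalone sketch of the logic above with hypothetical values (batch_size=128, world_size=8, opt='lamb' stand in for the parsed args). Note the last two steps, the square-root branch and lr = lr_base * batch_ratio, fall past the end of this hunk and are assumptions based on the --lr-base-scale and --lr-base help text, not quoted code.

# Standalone sketch, not part of the diff; all values are hypothetical.
batch_size, world_size = 128, 8      # per-device batch size and device count
lr_base, lr_base_size = 0.1, 256     # argparse defaults from the first hunk
lr_base_scale, opt = '', 'lamb'      # scale left empty => derived from opt name

global_batch_size = batch_size * world_size      # 1024
batch_ratio = global_batch_size / lr_base_size   # 4.0
if not lr_base_scale:
    on = opt.lower()
    lr_base_scale = 'sqrt' if any(o in on for o in ('adam', 'lamb', 'adabelief')) else 'linear'
# Assumed continuation, per the --lr-base-scale and --lr-base help text:
if lr_base_scale == 'sqrt':
    batch_ratio = batch_ratio ** 0.5             # 2.0 for 'lamb'
lr = lr_base * batch_ratio
print(lr)  # 0.2 with sqrt scaling; an 'sgd' opt would stay linear => 0.4

The sqrt-vs-linear split by optimizer family is a common heuristic (adaptive optimizers generally want gentler scaling with batch size than SGD); the code simply keys it off substrings of the optimizer name.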