Allow changing base lr batch size from 256 via arg

pull/1239/head
Ross Wightman 2 years ago
parent 7148039f9f
commit fafece230b

@@ -120,7 +120,9 @@ parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
 parser.add_argument('--lr', type=float, default=None, metavar='LR',
                     help='learning rate (default: None => --lr-base)')
 parser.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
-                    help='base learning rate: lr = lr_base * global_batch_size / 256')
+                    help='base learning rate: lr = lr_base * global_batch_size / base_size')
+parser.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
+                    help='base learning rate batch size (divisor, default: 256).')
 parser.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
                     help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
 parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
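
With the new divisor arg, the derived lr can be anchored to any reference batch size, not just 256. A hypothetical invocation (other train.py flags elided), linear-scaling a 0.5 base lr against a 512 base batch size:

    python train.py --lr-base 0.5 --lr-base-size 512 --lr-base-scale linear ...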
@@ -445,7 +447,7 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
     if args.lr is None:
         global_batch_size = args.batch_size * dev_env.world_size
-        batch_ratio = global_batch_size / 256
+        batch_ratio = global_batch_size / args.lr_base_size
         if not args.lr_base_scale:
             on = args.opt.lower()
             args.lr_base_scale = 'sqrt' if any([o in on for o in ('adam', 'lamb', 'adabelief')]) else 'linear'
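
For reference, a minimal standalone sketch of the lr resolution this hunk implements. resolve_base_lr is a hypothetical helper (the real logic is inline in setup_train_task), and the final sqrt/linear application is assumed from the --lr-base-scale help text since the hunk is truncated here:

    import math

    def resolve_base_lr(lr_base, batch_size, world_size,
                        lr_base_size=256, lr_base_scale='', opt='sgd'):
        # Mirrors the diff: scale the base lr by global_batch_size / lr_base_size.
        global_batch_size = batch_size * world_size
        batch_ratio = global_batch_size / lr_base_size
        if not lr_base_scale:
            # default scaling rule picked from the optimizer name, as in the hunk above
            on = opt.lower()
            lr_base_scale = 'sqrt' if any([o in on for o in ('adam', 'lamb', 'adabelief')]) else 'linear'
        if lr_base_scale == 'sqrt':
            # assumed continuation of the hunk, per the --lr-base-scale help text
            batch_ratio = math.sqrt(batch_ratio)
        return lr_base * batch_ratio

    # 8 workers x 128 per-device batch = 1024 global; linear scaling from 256:
    # resolve_base_lr(0.1, 128, 8) -> 0.1 * 1024 / 256 = 0.4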
