Add base learning rate (--lr-base) with linear and sqrt batch-size scaling to train script

pull/1239/head
Ross Wightman 2 years ago
parent 066e490605
commit f82fb6b608

@ -117,8 +117,12 @@ parser.add_argument('--clip-mode', type=str, default='norm',
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine"')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
help='learning rate (default: 0.05)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (default: None => --lr-base')
parser.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
help='base learning rate: lr = lr_base * global_batch_size / 256')
parser.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
@ -165,7 +169,7 @@ parser.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
parser.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
parser.add_argument('--color-jitter', type=float, default=None, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
@ -439,6 +443,18 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
assert args.aug_splits > 1
model = convert_splitbn_model(model, max(args.aug_splits, 2))
if args.lr is None:
global_batch_size = args.batch_size * dev_env.world_size
batch_ratio = global_batch_size / 256
if not args.lr_base_scale:
on = args.opt.lower()
args.lr_base_scale = 'sqrt' if any([o in on for o in ('adam', 'lamb', 'adabelief')]) else 'linear'
if args.lr_base_scale == 'sqrt':
batch_ratio = batch_ratio ** 0.5
args.lr = args.lr_base * batch_ratio
_logger.info(f'Calculated learning rate ({args.lr}) from base learning rate ({args.lr_base}) '
f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
train_state = setup_model_and_optimizer(
dev_env=dev_env,
model=model,

Loading…
Cancel
Save