diff --git a/train.py b/train.py
index 6e74f67b..0405f941 100755
--- a/train.py
+++ b/train.py
@@ -117,8 +117,12 @@ parser.add_argument('--clip-mode', type=str, default='norm',
 # Learning rate schedule parameters
 parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                     help='LR scheduler (default: "cosine"')
-parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
-                    help='learning rate (default: 0.05)')
+parser.add_argument('--lr', type=float, default=None, metavar='LR',
+                    help='learning rate (default: None => --lr-base)')
+parser.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
+                    help='base learning rate: lr = lr_base * global_batch_size / 256')
+parser.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
+                    help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
 parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                     help='learning rate noise on/off epoch percentages')
 parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
@@ -165,7 +169,7 @@ parser.add_argument('--hflip', type=float, default=0.5,
                     help='Horizontal flip training aug probability')
 parser.add_argument('--vflip', type=float, default=0.,
                     help='Vertical flip training aug probability')
-parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
-                    help='Color jitter factor (default: 0.4)')
+parser.add_argument('--color-jitter', type=float, default=None, metavar='PCT',
+                    help='Color jitter factor (default: None)')
 parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                     help='Use AutoAugment policy. "v0" or "original". (default: None)'),
@@ -439,6 +443,18 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
         assert args.aug_splits > 1
         model = convert_splitbn_model(model, max(args.aug_splits, 2))
 
+    if args.lr is None:
+        global_batch_size = args.batch_size * dev_env.world_size
+        batch_ratio = global_batch_size / 256
+        if not args.lr_base_scale:
+            on = args.opt.lower()
+            args.lr_base_scale = 'sqrt' if any([o in on for o in ('adam', 'lamb', 'adabelief')]) else 'linear'
+        if args.lr_base_scale == 'sqrt':
+            batch_ratio = batch_ratio ** 0.5
+        args.lr = args.lr_base * batch_ratio
+        _logger.info(f'Calculated learning rate ({args.lr}) from base learning rate ({args.lr_base}) '
+                     f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
+
     train_state = setup_model_and_optimizer(
         dev_env=dev_env,
         model=model,
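
For reference, the `--lr-base` resolution added above can be exercised standalone. A minimal sketch of the same logic (the `resolve_lr` helper and the example batch sizes are illustrative, not part of the patch; the 256-image reference batch and the optimizer-name heuristic mirror the diff):

```python
def resolve_lr(lr, lr_base, lr_base_scale, opt, batch_size, world_size):
    """Sketch of the --lr / --lr-base resolution performed in setup_train_task."""
    if lr is not None:
        return lr, lr_base_scale  # an explicit --lr always wins
    global_batch_size = batch_size * world_size
    batch_ratio = global_batch_size / 256  # 256 is the reference batch size
    if not lr_base_scale:
        # adaptive optimizers get sqrt scaling, SGD-style optimizers linear
        on = opt.lower()
        lr_base_scale = 'sqrt' if any(o in on for o in ('adam', 'lamb', 'adabelief')) else 'linear'
    if lr_base_scale == 'sqrt':
        batch_ratio = batch_ratio ** 0.5
    return lr_base * batch_ratio, lr_base_scale

# 8 devices x 128 per-device batch = 1024 global, lr_base = 0.1:
print(resolve_lr(None, 0.1, '', 'sgd', 128, 8))    # (0.4, 'linear'): 0.1 * 1024/256
print(resolve_lr(None, 0.1, '', 'adamw', 128, 8))  # (0.2, 'sqrt'):   0.1 * sqrt(4)
```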