@@ -117,8 +117,12 @@ parser.add_argument('--clip-mode', type=str, default='norm',
 # Learning rate schedule parameters
 parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                     help='LR scheduler (default: "cosine"')
-parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
-                    help='learning rate (default: 0.05)')
+parser.add_argument('--lr', type=float, default=None, metavar='LR',
+                    help='learning rate (default: None => --lr-base')
+parser.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
+                    help='base learning rate: lr = lr_base * global_batch_size / 256')
+parser.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
+                    help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
 parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                     help='learning rate noise on/off epoch percentages')
 parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
@@ -165,7 +169,7 @@ parser.add_argument('--hflip', type=float, default=0.5,
                     help='Horizontal flip training aug probability')
 parser.add_argument('--vflip', type=float, default=0.,
                     help='Vertical flip training aug probability')
-parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
+parser.add_argument('--color-jitter', type=float, default=None, metavar='PCT',
                     help='Color jitter factor (default: 0.4)')
 parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                     help='Use AutoAugment policy. "v0" or "original". (default: None)'),
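For reference, the setup_train_task hunk below derives args.lr from --lr-base whenever --lr is left unset. The following is a minimal standalone sketch of that scaling rule, not part of the patch; the resolve_lr helper name and the example batch sizes are hypothetical:

# Standalone sketch of the lr-base scaling rule applied in setup_train_task below.
# The resolve_lr name and the example numbers are illustrative, not from the patch.
def resolve_lr(lr_base, batch_size, world_size, opt='sgd', lr_base_scale=''):
    global_batch_size = batch_size * world_size
    batch_ratio = global_batch_size / 256
    if not lr_base_scale:
        # adaptive optimizers default to sqrt scaling, SGD-like optimizers to linear
        lr_base_scale = 'sqrt' if any(o in opt.lower() for o in ('adam', 'lamb', 'adabelief')) else 'linear'
    if lr_base_scale == 'sqrt':
        batch_ratio = batch_ratio ** 0.5
    return lr_base * batch_ratio

print(resolve_lr(0.1, batch_size=128, world_size=4))               # linear: 0.1 * 512 / 256 = 0.2
print(resolve_lr(0.1, batch_size=128, world_size=4, opt='adamw'))  # sqrt: 0.1 * sqrt(2) ~= 0.141
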
@@ -439,6 +443,18 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
         assert args.aug_splits > 1
         model = convert_splitbn_model(model, max(args.aug_splits, 2))
 
+    if args.lr is None:
+        global_batch_size = args.batch_size * dev_env.world_size
+        batch_ratio = global_batch_size / 256
+        if not args.lr_base_scale:
+            on = args.opt.lower()
+            args.lr_base_scale = 'sqrt' if any([o in on for o in ('adam', 'lamb', 'adabelief')]) else 'linear'
+        if args.lr_base_scale == 'sqrt':
+            batch_ratio = batch_ratio ** 0.5
+        args.lr = args.lr_base * batch_ratio
+        _logger.info(f'Calculated learning rate ({args.lr}) from base learning rate ({args.lr_base}) '
+                     f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
+
     train_state = setup_model_and_optimizer(
         dev_env=dev_env,
         model=model,