Make min-lr and cooldown-epochs cmdline args; change underscore to dash in color_jitter arg for consistency

pull/23/head
Ross Wightman 5 years ago
parent d4debe6597
commit e7c8a37334

@ -11,26 +11,26 @@ def create_scheduler(args, optimizer):
optimizer, optimizer,
t_initial=num_epochs, t_initial=num_epochs,
t_mul=1.0, t_mul=1.0,
lr_min=1e-5, lr_min=args.min_lr,
decay_rate=args.decay_rate, decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr, warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs, warmup_t=args.warmup_epochs,
cycle_limit=1, cycle_limit=1,
t_in_epochs=True, t_in_epochs=True,
) )
num_epochs = lr_scheduler.get_cycle_length() + 10 num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
elif args.sched == 'tanh': elif args.sched == 'tanh':
lr_scheduler = TanhLRScheduler( lr_scheduler = TanhLRScheduler(
optimizer, optimizer,
t_initial=num_epochs, t_initial=num_epochs,
t_mul=1.0, t_mul=1.0,
lr_min=1e-5, lr_min=args.min_lr,
warmup_lr_init=args.warmup_lr, warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs, warmup_t=args.warmup_epochs,
cycle_limit=1, cycle_limit=1,
t_in_epochs=True, t_in_epochs=True,
) )
num_epochs = lr_scheduler.get_cycle_length() + 10 num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
else: else:
lr_scheduler = StepLRScheduler( lr_scheduler = StepLRScheduler(
optimizer, optimizer,

@ -70,6 +70,8 @@ parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)') help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)') help='warmup learning rate (default: 0.0001)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--epochs', type=int, default=200, metavar='N', parser.add_argument('--epochs', type=int, default=200, metavar='N',
help='number of epochs to train (default: 2)') help='number of epochs to train (default: 2)')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N', parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
@ -78,10 +80,12 @@ parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
help='epoch interval to decay LR') help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
help='epochs to warmup LR, if scheduler supports') help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)') help='LR decay rate (default: 0.1)')
# Augmentation parameters # Augmentation parameters
parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)') help='Color jitter factor (default: 0.4)')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT', parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)') help='Random erase prob (default: 0.)')

Loading…
Cancel
Save