diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index e64c34d1..8f1032a1 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -11,26 +11,26 @@ def create_scheduler(args, optimizer): optimizer, t_initial=num_epochs, t_mul=1.0, - lr_min=1e-5, + lr_min=args.min_lr, decay_rate=args.decay_rate, warmup_lr_init=args.warmup_lr, warmup_t=args.warmup_epochs, cycle_limit=1, t_in_epochs=True, ) - num_epochs = lr_scheduler.get_cycle_length() + 10 + num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs elif args.sched == 'tanh': lr_scheduler = TanhLRScheduler( optimizer, t_initial=num_epochs, t_mul=1.0, - lr_min=1e-5, + lr_min=args.min_lr, warmup_lr_init=args.warmup_lr, warmup_t=args.warmup_epochs, cycle_limit=1, t_in_epochs=True, ) - num_epochs = lr_scheduler.get_cycle_length() + 10 + num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs else: lr_scheduler = StepLRScheduler( optimizer, diff --git a/train.py b/train.py index f7ecdd5d..51006a0d 100644 --- a/train.py +++ b/train.py @@ -70,6 +70,8 @@ parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)') parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', help='warmup learning rate (default: 0.0001)') +parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 2)') parser.add_argument('--start-epoch', default=None, type=int, metavar='N', @@ -78,10 +80,12 @@ parser.add_argument('--decay-epochs', type=int, default=30, metavar='N', help='epoch interval to decay LR') parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', help='epochs to warmup LR, if scheduler supports') +parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', + help='epochs to cooldown LR at min_lr, after cyclic schedule ends') parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') # Augmentation parameters -parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', +parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') parser.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)')