From abf3e044bb44571a03f218ab6608ff4fb03b2782 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Sat, 14 Aug 2021 22:53:17 +0200 Subject: [PATCH] Update scheduler_factory.py remove duplicate code from create_scheduler() --- timm/scheduler/scheduler_factory.py | 31 ++++++++++------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index 5f5a4ca8..51b65e00 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -21,6 +21,12 @@ def create_scheduler(args, optimizer): noise_range = lr_noise * num_epochs else: noise_range = None + noise_args = dict( + noise_range_t=noise_range, + noise_pct=getattr(args, 'lr_noise_pct', 0.67), + noise_std=getattr(args, 'lr_noise_std', 1.), + noise_seed=getattr(args, 'seed', 42), + ) lr_scheduler = None if args.sched == 'cosine': @@ -34,10 +40,7 @@ def create_scheduler(args, optimizer): warmup_t=args.warmup_epochs, cycle_limit=getattr(args, 'lr_cycle_limit', 1), t_in_epochs=True, - noise_range_t=noise_range, - noise_pct=getattr(args, 'lr_noise_pct', 0.67), - noise_std=getattr(args, 'lr_noise_std', 1.), - noise_seed=getattr(args, 'seed', 42), + **noise_args, ) num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs elif args.sched == 'tanh': @@ -50,10 +53,7 @@ def create_scheduler(args, optimizer): warmup_t=args.warmup_epochs, cycle_limit=getattr(args, 'lr_cycle_limit', 1), t_in_epochs=True, - noise_range_t=noise_range, - noise_pct=getattr(args, 'lr_noise_pct', 0.67), - noise_std=getattr(args, 'lr_noise_std', 1.), - noise_seed=getattr(args, 'seed', 42), + **noise_args, ) num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs elif args.sched == 'step': @@ -63,10 +63,7 @@ def create_scheduler(args, optimizer): decay_rate=args.decay_rate, warmup_lr_init=args.warmup_lr, warmup_t=args.warmup_epochs, - noise_range_t=noise_range, - noise_pct=getattr(args, 'lr_noise_pct', 0.67), - noise_std=getattr(args, 'lr_noise_std', 1.), - noise_seed=getattr(args, 'seed', 42), + **noise_args, ) elif args.sched == 'multistep': lr_scheduler = MultiStepLRScheduler( @@ -75,10 +72,7 @@ def create_scheduler(args, optimizer): decay_rate=args.decay_rate, warmup_lr_init=args.warmup_lr, warmup_t=args.warmup_epochs, - noise_range_t=noise_range, - noise_pct=getattr(args, 'lr_noise_pct', 0.67), - noise_std=getattr(args, 'lr_noise_std', 1.), - noise_seed=getattr(args, 'seed', 42), + **noise_args, ) elif args.sched == 'plateau': mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max' @@ -91,10 +85,7 @@ def create_scheduler(args, optimizer): warmup_lr_init=args.warmup_lr, warmup_t=args.warmup_epochs, cooldown_t=0, - noise_range_t=noise_range, - noise_pct=getattr(args, 'lr_noise_pct', 0.67), - noise_std=getattr(args, 'lr_noise_std', 1.), - noise_seed=getattr(args, 'seed', 42), + **noise_args, ) return lr_scheduler, num_epochs