From 514b0938c47632744daa140de5668ed7357b1f9d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 21 Feb 2020 11:51:05 -0800 Subject: [PATCH 1/4] Experimenting with per-epoch learning rate noise w/ step scheduler --- timm/scheduler/scheduler_factory.py | 7 +++++++ timm/scheduler/step_lr.py | 21 ++++++++++++++++++--- train.py | 4 ++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index 80d37b37..dca8a580 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -33,11 +33,18 @@ def create_scheduler(args, optimizer): ) num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs elif args.sched == 'step': + if isinstance(args.lr_noise, (list, tuple)): + noise_range = [n * num_epochs for n in args.lr_noise] + else: + noise_range = args.lr_noise * num_epochs + print(noise_range) lr_scheduler = StepLRScheduler( optimizer, decay_t=args.decay_epochs, decay_rate=args.decay_rate, warmup_lr_init=args.warmup_lr, warmup_t=args.warmup_epochs, + noise_range_t=noise_range, + noise_std=args.lr_noise_std, ) return lr_scheduler, num_epochs diff --git a/timm/scheduler/step_lr.py b/timm/scheduler/step_lr.py index 5ee8b90f..d3060fd8 100644 --- a/timm/scheduler/step_lr.py +++ b/timm/scheduler/step_lr.py @@ -14,14 +14,19 @@ class StepLRScheduler(Scheduler): decay_rate: float = 1., warmup_t=0, warmup_lr_init=0, + noise_range_t=None, + noise_std=1.0, t_in_epochs=True, - initialize=True) -> None: + initialize=True, + ) -> None: super().__init__(optimizer, param_group_field="lr", initialize=initialize) self.decay_t = decay_t self.decay_rate = decay_rate self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init + self.noise_range_t = noise_range_t + self.noise_std = noise_std self.t_in_epochs = t_in_epochs if self.warmup_t: self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] @@ -33,8 +38,18 @@ class StepLRScheduler(Scheduler): if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: - lrs = [v * (self.decay_rate ** (t // self.decay_t)) - for v in self.base_values] + lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] + if self.noise_range_t is not None: + if isinstance(self.noise_range_t, (list, tuple)): + apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] + else: + apply_noise = t >= self.noise_range_t + if apply_noise: + g = torch.Generator() + g.manual_seed(t) + lr_mult = torch.randn(1, generator=g).item() * self.noise_std + 1. 
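# A minimal standalone sketch of the mechanism used just above (illustrative only, not
# part of this patch): seeding a fresh torch.Generator with the epoch index makes the
# per-epoch LR perturbation reproducible, so a given epoch always draws the same
# multiplier on every run and every process. Values in the example call are made up.
import torch

def noisy_lr(base_lr: float, epoch: int, noise_std: float = 1.0) -> float:
    g = torch.Generator()
    g.manual_seed(epoch)                                    # same epoch -> same draw, every run
    lr_mult = torch.randn(1, generator=g).item() * noise_std + 1.
    return min(5 * base_lr, max(base_lr / 5, base_lr * lr_mult))  # clamp to [lr/5, 5*lr]

# e.g. noisy_lr(0.01, 10) returns the same perturbed LR each time it is called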
+ lrs = [min(5 * v, max(v / 5, v * lr_mult)) for v in lrs] + print(lrs) return lrs def get_epoch_values(self, epoch: int): diff --git a/train.py b/train.py index 7b4e1af0..c9d73833 100755 --- a/train.py +++ b/train.py @@ -105,6 +105,10 @@ parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', help='LR scheduler (default: "step"') parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)') +parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', + help='learning rate noise on/off epoch percentages') +parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', + help='learning rate nose std-dev (default: 1.0)') parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', help='warmup learning rate (default: 0.0001)') parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', From 27b3680d49e0e7c30ae6dd25f0daeb8de5319b40 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 22 Feb 2020 16:23:15 -0800 Subject: [PATCH 2/4] Revamp LR noise, move logic to scheduler base. Fixup PlateauLRScheduler and add it as an option. --- timm/scheduler/cosine_lr.py | 9 +++++- timm/scheduler/plateau_lr.py | 50 ++++++++++++----------------- timm/scheduler/scheduler.py | 32 ++++++++++++++++++ timm/scheduler/scheduler_factory.py | 37 ++++++++++++++++++--- timm/scheduler/step_lr.py | 24 +++++--------- timm/scheduler/tanh_lr.py | 9 +++++- train.py | 6 +++- 7 files changed, 114 insertions(+), 53 deletions(-) diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py index f2a85931..f7cb204c 100644 --- a/timm/scheduler/cosine_lr.py +++ b/timm/scheduler/cosine_lr.py @@ -29,8 +29,15 @@ class CosineLRScheduler(Scheduler): warmup_prefix=False, cycle_limit=0, t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, initialize=True) -> None: - super().__init__(optimizer, param_group_field="lr", initialize=initialize) + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) assert t_initial > 0 assert lr_min >= 0 diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py index 0cad2159..8129459b 100644 --- a/timm/scheduler/plateau_lr.py +++ b/timm/scheduler/plateau_lr.py @@ -8,33 +8,34 @@ class PlateauLRScheduler(Scheduler): def __init__(self, optimizer, - factor=0.1, - patience=10, - verbose=False, + decay_rate=0.1, + patience_t=10, + verbose=True, threshold=1e-4, - cooldown_epochs=0, - warmup_updates=0, + cooldown_t=0, + warmup_t=0, warmup_lr_init=0, lr_min=0, + mode='min', + initialize=True, ): - super().__init__(optimizer, 'lr', initialize=False) + super().__init__(optimizer, 'lr', initialize=initialize) self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - self.optimizer.optimizer, - patience=patience, - factor=factor, + self.optimizer, + patience=patience_t, + factor=decay_rate, verbose=verbose, threshold=threshold, - cooldown=cooldown_epochs, + cooldown=cooldown_t, + mode=mode, min_lr=lr_min ) - self.warmup_updates = warmup_updates + self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init - - if self.warmup_updates: - self.warmup_active = warmup_updates > 0 # this state updates with num_updates - self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values] + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for 
v in self.base_values] super().update_groups(self.warmup_lr_init) else: self.warmup_steps = [1 for _ in self.base_values] @@ -51,18 +52,9 @@ class PlateauLRScheduler(Scheduler): self.lr_scheduler.last_epoch = state_dict['last_epoch'] # override the base class step fn completely - def step(self, epoch, val_loss=None): - """Update the learning rate at the end of the given epoch.""" - if val_loss is not None and not self.warmup_active: - self.lr_scheduler.step(val_loss, epoch) - else: - self.lr_scheduler.last_epoch = epoch - - def get_update_values(self, num_updates: int): - if num_updates < self.warmup_updates: - lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps] + def step(self, epoch, metric=None): + if epoch <= self.warmup_t: + lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps] + super().update_groups(lrs) else: - self.warmup_active = False # warmup cancelled by first update past warmup_update count - lrs = None # no change on update after warmup stage - return lrs - + self.lr_scheduler.step(metric, epoch) diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index 78e8460d..21d51509 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -25,6 +25,11 @@ class Scheduler: def __init__(self, optimizer: torch.optim.Optimizer, param_group_field: str, + noise_range_t=None, + noise_type='normal', + noise_pct=0.67, + noise_std=1.0, + noise_seed=None, initialize: bool = True) -> None: self.optimizer = optimizer self.param_group_field = param_group_field @@ -40,6 +45,11 @@ class Scheduler: raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] self.metric = None # any point to having this for all? 
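# A condensed sketch (illustrative, not part of this patch) of the noise_range_t
# convention the attributes added just below encode: a scalar means "apply noise from
# that epoch/update onward", while a (start, end) pair means start <= t < end.
def _in_noise_range(noise_range_t, t):
    if noise_range_t is None:
        return False
    if isinstance(noise_range_t, (list, tuple)):
        return noise_range_t[0] <= t < noise_range_t[1]
    return t >= noise_range_t

# e.g. _in_noise_range((252, 540), 300) -> True; _in_noise_range(252, 100) -> False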
+ self.noise_range_t = noise_range_t + self.noise_pct = noise_pct + self.noise_type = noise_type + self.noise_std = noise_std + self.noise_seed = noise_seed if noise_seed is not None else 42 self.update_groups(self.base_values) def state_dict(self) -> Dict[str, Any]: @@ -58,12 +68,14 @@ class Scheduler: self.metric = metric values = self.get_epoch_values(epoch) if values is not None: + values = self._add_noise(values, epoch) self.update_groups(values) def step_update(self, num_updates: int, metric: float = None): self.metric = metric values = self.get_update_values(num_updates) if values is not None: + values = self._add_noise(values, num_updates) self.update_groups(values) def update_groups(self, values): @@ -71,3 +83,23 @@ class Scheduler: values = [values] * len(self.optimizer.param_groups) for param_group, value in zip(self.optimizer.param_groups, values): param_group[self.param_group_field] = value + + def _add_noise(self, lrs, t): + if self.noise_range_t is not None: + if isinstance(self.noise_range_t, (list, tuple)): + apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] + else: + apply_noise = t >= self.noise_range_t + if apply_noise: + g = torch.Generator() + g.manual_seed(self.noise_seed + t) + if self.noise_type == 'normal': + while True: + # resample if noise out of percent limit, brute force but shouldn't spin much + noise = torch.randn(1, generator=g).item() + if abs(noise) < self.noise_pct: + break + else: + noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct + lrs = [v + v * noise for v in lrs] + return lrs diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index dca8a580..ffe858ad 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -1,10 +1,21 @@ from .cosine_lr import CosineLRScheduler from .tanh_lr import TanhLRScheduler from .step_lr import StepLRScheduler +from .plateau_lr import PlateauLRScheduler def create_scheduler(args, optimizer): num_epochs = args.epochs + + if args.lr_noise is not None: + if isinstance(args.lr_noise, (list, tuple)): + noise_range = [n * num_epochs for n in args.lr_noise] + else: + noise_range = args.lr_noise * num_epochs + print('Noise range:', noise_range) + else: + noise_range = None + lr_scheduler = None #FIXME expose cycle parms of the scheduler config to arguments if args.sched == 'cosine': @@ -18,6 +29,10 @@ def create_scheduler(args, optimizer): warmup_t=args.warmup_epochs, cycle_limit=1, t_in_epochs=True, + noise_range_t=noise_range, + noise_pct=args.lr_noise_pct, + noise_std=args.lr_noise_std, + noise_seed=args.seed, ) num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs elif args.sched == 'tanh': @@ -30,14 +45,13 @@ def create_scheduler(args, optimizer): warmup_t=args.warmup_epochs, cycle_limit=1, t_in_epochs=True, + noise_range_t=noise_range, + noise_pct=args.lr_noise_pct, + noise_std=args.lr_noise_std, + noise_seed=args.seed, ) num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs elif args.sched == 'step': - if isinstance(args.lr_noise, (list, tuple)): - noise_range = [n * num_epochs for n in args.lr_noise] - else: - noise_range = args.lr_noise * num_epochs - print(noise_range) lr_scheduler = StepLRScheduler( optimizer, decay_t=args.decay_epochs, @@ -45,6 +59,19 @@ def create_scheduler(args, optimizer): warmup_lr_init=args.warmup_lr, warmup_t=args.warmup_epochs, noise_range_t=noise_range, + noise_pct=args.lr_noise_pct, noise_std=args.lr_noise_std, + noise_seed=args.seed, + ) + elif args.sched == 
'plateau': + lr_scheduler = PlateauLRScheduler( + optimizer, + decay_rate=args.decay_rate, + patience_t=args.patience_epochs, + lr_min=args.min_lr, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cooldown_t=args.cooldown_epochs, ) + return lr_scheduler, num_epochs diff --git a/timm/scheduler/step_lr.py b/timm/scheduler/step_lr.py index d3060fd8..b3c75d96 100644 --- a/timm/scheduler/step_lr.py +++ b/timm/scheduler/step_lr.py @@ -10,23 +10,26 @@ class StepLRScheduler(Scheduler): def __init__(self, optimizer: torch.optim.Optimizer, - decay_t: int, + decay_t: float, decay_rate: float = 1., warmup_t=0, warmup_lr_init=0, + t_in_epochs=True, noise_range_t=None, + noise_pct=0.67, noise_std=1.0, - t_in_epochs=True, + noise_seed=42, initialize=True, ) -> None: - super().__init__(optimizer, param_group_field="lr", initialize=initialize) + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) self.decay_t = decay_t self.decay_rate = decay_rate self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init - self.noise_range_t = noise_range_t - self.noise_std = noise_std self.t_in_epochs = t_in_epochs if self.warmup_t: self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] @@ -39,17 +42,6 @@ class StepLRScheduler(Scheduler): lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values] - if self.noise_range_t is not None: - if isinstance(self.noise_range_t, (list, tuple)): - apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] - else: - apply_noise = t >= self.noise_range_t - if apply_noise: - g = torch.Generator() - g.manual_seed(t) - lr_mult = torch.randn(1, generator=g).item() * self.noise_std + 1. 
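# The noise code being removed from StepLRScheduler in this hunk now lives in the
# Scheduler base class. A hypothetical minimal subclass (not part of this patch, class
# name made up, import path per the file layout above) showing how any scheduler now
# opts in simply by forwarding the noise_* kwargs and implementing get_epoch_values():
from timm.scheduler.scheduler import Scheduler

class ConstantLRWithNoise(Scheduler):
    def __init__(self, optimizer, noise_range_t=None, noise_pct=0.67,
                 noise_std=1.0, noise_seed=42):
        super().__init__(optimizer, param_group_field="lr",
                         noise_range_t=noise_range_t, noise_pct=noise_pct,
                         noise_std=noise_std, noise_seed=noise_seed)

    def get_epoch_values(self, epoch: int):
        return list(self.base_values)      # constant LR; the base step() layers noise on top

    def get_update_values(self, num_updates: int):
        return None                        # no per-update schedule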
- lrs = [min(5 * v, max(v / 5, v * lr_mult)) for v in lrs] - print(lrs) return lrs def get_epoch_values(self, epoch: int): diff --git a/timm/scheduler/tanh_lr.py b/timm/scheduler/tanh_lr.py index cb257d0b..241727de 100644 --- a/timm/scheduler/tanh_lr.py +++ b/timm/scheduler/tanh_lr.py @@ -28,8 +28,15 @@ class TanhLRScheduler(Scheduler): warmup_prefix=False, cycle_limit=0, t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, initialize=True) -> None: - super().__init__(optimizer, param_group_field="lr", initialize=initialize) + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) assert t_initial > 0 assert lr_min >= 0 diff --git a/train.py b/train.py index c9d73833..a6fd0e47 100755 --- a/train.py +++ b/train.py @@ -107,8 +107,10 @@ parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)') parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages') +parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', + help='learning rate noise limit percent (default: 0.67)') parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', - help='learning rate nose std-dev (default: 1.0)') + help='learning rate noise std-dev (default: 1.0)') parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', help='warmup learning rate (default: 0.0001)') parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', @@ -123,6 +125,8 @@ parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', help='epochs to warmup LR, if scheduler supports') parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends') +parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', + help='patience epochs for Plateau LR scheduler (default: 10') parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') # Augmentation parameters From 9fee316752071888a09e87d3298a13e93c841075 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 24 Feb 2020 15:11:26 -0800 Subject: [PATCH 3/4] Enable fixed fanout calc in EfficientNet/MobileNetV3 weight init by default. Fix #84 --- timm/models/efficientnet_builder.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 137705de..f8f0df8a 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -359,15 +359,13 @@ class EfficientNetBuilder: return stages -def _init_weight_goog(m, n='', fix_group_fanout=False): +def _init_weight_goog(m, n='', fix_group_fanout=True): """ Weight initialization as per Tensorflow official implementations. 
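# An illustrative sketch (not the verbatim body of _init_weight_goog) of the calculation
# that fix_group_fanout=True now enables by default: grouped/depthwise convolutions split
# their output channels across `groups`, so fan-out is divided by groups before the
# Tensorflow-style normal init, keeping init variance in line with the TF TPU references.
import math
import torch.nn as nn

def conv_fanout_init(m: nn.Conv2d, fix_group_fanout: bool = True) -> None:
    fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    if fix_group_fanout:
        fan_out //= m.groups               # e.g. depthwise conv: groups == in_channels
    m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
    if m.bias is not None:
        m.bias.data.zero_()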
Args: m (nn.Module): module to init n (str): module name - fix_group_fanout (bool): enable correct fanout calculation w/ group convs - - FIXME change fix_group_fanout to default to True if experiments show better training results + fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py From c16f25ced2ff9b41895a12ba9c967b272ceb311e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 29 Feb 2020 20:37:20 -0800 Subject: [PATCH 4/4] Add MobileNetV3 Large weights, results, update README and sotabench for merge --- README.md | 16 +++++++++++++++- sotabench.py | 2 +- timm/models/mobilenetv3.py | 4 +++- timm/scheduler/scheduler_factory.py | 3 ++- 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index eaef67e8..62c93cce 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,14 @@ ## What's New +### Feb 29, 2020 +* New MobileNet-V3 Large weights trained from stratch with this code to 75.77% top-1 +* IMPORTANT CHANGE - default weight init changed for all MobilenetV3 / EfficientNet / related models + * overall results similar to a bit better training from scratch on a few smaller models tried + * performance early in training seems consistently improved but less difference by end + * set `fix_group_fanout=False` in `_init_weight_goog` fn if you need to reproducte past behaviour +* Experimental LR noise feature added applies a random perturbation to LR each epoch in specified range of training + ### Feb 18, 2020 * Big refactor of model layers and addition of several attention mechanisms. Several additions motivated by 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268): * Move layer/module impl into `layers` subfolder/module of `models` and organize in a more granular fashion @@ -187,7 +195,8 @@ I've leveraged the training scripts in this repository to train a few of the mod | skresnet34 | 76.912 (23.088) | 93.322 (6.678) | 22.2M | bicubic | 224 | | resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic | 224 | | mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13M | bicubic | 224 | -| mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 | +| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5M | bicubic | 224 | +| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 | | mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.89M | bicubic | 224 | | resnet26 | 75.292 (24.708) | 92.57 (7.43) | 16M | bicubic | 224 | | fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6M | bilinear | 224 | @@ -361,6 +370,11 @@ Trained by [Andrew Lavin](https://github.com/andravin) with 8 V100 cards. 
Model `./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064` +### MobileNetV3-Large-100 - 75.766 top-1, 92,542 top-5 + +`./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9` + + **TODO dig up some more** diff --git a/sotabench.py b/sotabench.py index 459993bd..7b896819 100644 --- a/sotabench.py +++ b/sotabench.py @@ -93,7 +93,7 @@ model_list = [ _entry('semnasnet_100', 'MnasNet-A1', '1807.11626'), _entry('spnasnet_100', 'Single-Path NAS', '1904.02877', model_desc='Trained in PyTorch with SGD, cosine LR decay'), - _entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244', + _entry('mobilenetv3_large_100', 'MobileNet V3-Large 1.0', '1905.02244', model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching ' 'paper as closely as possible.'), diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 39391c56..fe90767c 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -31,7 +31,9 @@ def _cfg(url='', **kwargs): default_cfgs = { 'mobilenetv3_large_075': _cfg(url=''), - 'mobilenetv3_large_100': _cfg(url=''), + 'mobilenetv3_large_100': _cfg( + interpolation='bicubic', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'), 'mobilenetv3_small_075': _cfg(url=''), 'mobilenetv3_small_100': _cfg(url=''), 'mobilenetv3_rw': _cfg( diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index ffe858ad..2320c96b 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -10,9 +10,10 @@ def create_scheduler(args, optimizer): if args.lr_noise is not None: if isinstance(args.lr_noise, (list, tuple)): noise_range = [n * num_epochs for n in args.lr_noise] + if len(noise_range) == 1: + noise_range = noise_range[0] else: noise_range = args.lr_noise * num_epochs - print('Noise range:', noise_range) else: noise_range = None
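# An end-to-end usage sketch (illustrative; SimpleNamespace stands in for the parsed
# argparse args, and the optimizer/metric below are placeholders) of how the pieces in
# this series fit together: the --lr-noise percentages are scaled by the epoch count
# inside create_scheduler, and the training loop steps the scheduler once per epoch,
# passing the eval metric so a metric-aware scheduler such as plateau can use it.
from types import SimpleNamespace
import torch
from timm.scheduler.scheduler_factory import create_scheduler

args = SimpleNamespace(
    epochs=600, sched='step', lr=0.064, lr_noise=[0.42, 0.9], lr_noise_pct=0.67,
    lr_noise_std=1.0, seed=42, warmup_lr=1e-6, warmup_epochs=3, min_lr=1e-5,
    decay_epochs=2.4, decay_rate=0.973, cooldown_epochs=10, patience_epochs=10)

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
# lr_noise [0.42, 0.9] * 600 epochs -> noise applied for epochs 252 <= t < 540

for epoch in range(num_epochs):
    # ... train for one epoch, then validate ...
    eval_metric = 0.0                          # placeholder validation loss/accuracy
    lr_scheduler.step(epoch + 1, eval_metric)  # plateau uses the metric, others ignore it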