diff --git a/README.md b/README.md
index 7ebe9341..7da6e32e 100644
--- a/README.md
+++ b/README.md
@@ -64,28 +64,6 @@ Bunch of changes:
 ### Feb 12, 2020
 * Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
 
-### Feb 6, 2020
-* Add RandAugment trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)
-
-### Feb 1/2, 2020
-* Port new EfficientNet-B8 (RandAugment) weights, these are different than the B8 AdvProp, different input normalization.
-* Update results csv files on all models for ImageNet validation and three other test sets
-* Push PyPi package update
-
-### Jan 31, 2020
-* Update ResNet50 weights with a new 79.038 result from further JSD / AugMix experiments. Full command line for reproduction in training section below.
-
-### Jan 11/12, 2020
-* Master may be a bit unstable wrt to training, these changes have been tested but not all combos
-* Implementations of AugMix added to existing RA and AA. Including numerous supporting pieces like JSD loss (Jensen-Shannon divergence + CE), and AugMixDataset
-* SplitBatchNorm adaptation layer added for implementing Auxiliary BN as per AdvProp paper
-* ResNet-50 AugMix trained model w/ 79% top-1 added
-* `seresnext26tn_32x4d` - 77.99 top-1, 93.75 top-5 added to tiered experiment, higher img/s than 't' and 'd'
-
-### Jan 3, 2020
-* Add RandAugment trained EfficientNet-B0 weight with 77.7 top-1. Trained by [Michael Klachko](https://github.com/michaelklachko) with this code and recent hparams (see Training section)
-* Add `avg_checkpoints.py` script for post training weight averaging and update all scripts with header docstrings and shebangs.
-
 ## Introduction
 
 For each competition, personal, or freelance project involving images + Convolution Neural Networks, I build on top of an evolving collection of code and models. This repo contains a (somewhat) cleaned up and paired down iteration of that code. Hopefully it'll be of use to others.
@@ -119,6 +97,7 @@ Included models:
 * DenseNet-121, DenseNet-169, DenseNet-201, DenseNet-161
 * Squeeze-and-Excitation ResNet/ResNeXt (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch) with some pretrained weight additions by myself)
   * SENet-154, SE-ResNet-18, SE-ResNet-34, SE-ResNet-50, SE-ResNet-101, SE-ResNet-152, SE-ResNeXt-26 (32x4d), SE-ResNeXt50 (32x4d), SE-ResNeXt101 (32x4d)
+* Inception-V3 (from [torchvision](https://github.com/pytorch/vision/tree/master/torchvision/models))
 * Inception-ResNet-V2 and Inception-V4 (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch) )
 * Xception
   * Original variant from [Cadene](https://github.com/Cadene/pretrained-models.pytorch)
@@ -143,6 +122,12 @@ Included models:
   * code from https://github.com/mehtadushy/SelecSLS-Pytorch, paper https://arxiv.org/abs/1907.00837
 * TResNet
   * code from https://github.com/mrT23/TResNet, paper https://arxiv.org/abs/2003.13630
+* RegNet
+  * paper `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678
+  * reference code at https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
+* VovNet V2 (with V1 support)
+  * paper `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
+  * reference code at https://github.com/youngwanLEE/vovnet-detectron2
 
 Use the `--model` arg to specify model for train, validation, inference scripts.
 Match the all lowercase creation fn for the model you'd like.
@@ -187,6 +172,8 @@ I've leveraged the training scripts in this repository to train a few of the mod
 | skresnext50d_32x4d | 80.156 (19.844) | 94.642 (5.358) | 27.5M | bicubic | 224 |
 | resnext50_32x4d | 79.762 (20.238) | 94.600 (5.400) | 25M | bicubic | 224 |
 | resnext50d_32x4d | 79.674 (20.326) | 94.868 (5.132) | 25.1M | bicubic | 224 |
+| ese_vovnet39b | 79.320 (20.680) | 94.710 (5.290) | 24.6M | bicubic | 224 |
+| resnetblur50 | 79.290 (20.710) | 94.632 (5.368) | 25.6M | bicubic | 224 |
 | resnet50 | 79.038 (20.962) | 94.390 (5.610) | 25.6M | bicubic | 224 |
 | mixnet_l | 78.976 (21.024 | 94.184 (5.816) | 7.33M | bicubic | 224 |
 | efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.79M | bicubic | 240 |
@@ -200,6 +187,7 @@ I've leveraged the training scripts in this repository to train a few of the mod
 | seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic | 224 |
 | skresnet34 | 76.912 (23.088) | 93.322 (6.678) | 22.2M | bicubic | 224 |
 | resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic | 224 |
+| densenetblur121d | 76.576 (23.424) | 93.190 (6.810) | 8.0M | bicubic | 224 |
 | mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1M | bicubic | 224 |
 | mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13M | bicubic | 224 |
 | mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5M | bicubic | 224 |
diff --git a/sotabench.py b/sotabench.py
index 1d7a0590..93c15d76 100644
--- a/sotabench.py
+++ b/sotabench.py
@@ -396,6 +396,24 @@ model_list = [
     _entry('selecsls60b', 'SelecSLS-60_B', '1907.00837',
            model_desc='Originally from https://github.com/mehtadushy/SelecSLS-Pytorch'),
 
+    ## ResNeSt official impl weights
+    _entry('resnest14d', 'ResNeSt-14', '2004.08955',
+           model_desc='Originally from GluonCV'),
+    _entry('resnest26d', 'ResNeSt-26', '2004.08955',
+           model_desc='Originally from GluonCV'),
+    _entry('resnest50d', 'ResNeSt-50', '2004.08955',
+           model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
+    _entry('resnest101e', 'ResNeSt-101', '2004.08955',
+           model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
+    _entry('resnest200e', 'ResNeSt-200', '2004.08955',
+           model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
+    _entry('resnest269e', 'ResNeSt-269', '2004.08955', batch_size=BATCH_SIZE // 2,
+           model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
+    _entry('resnest50d_4s2x40d', 'ResNeSt-50 4s2x40d', '2004.08955',
+           model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
+    _entry('resnest50d_1s4x24d', 'ResNeSt-50 1s4x24d', '2004.08955',
+           model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
+
     ## RegNet official impl weighs
     _entry('regnetx_002', 'RegNetX-200MF', '2003.13678'),
     _entry('regnetx_004', 'RegNetX-400MF', '2003.13678'),
diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py
index 8129459b..955178ad 100644
--- a/timm/scheduler/plateau_lr.py
+++ b/timm/scheduler/plateau_lr.py
@@ -16,7 +16,12 @@ class PlateauLRScheduler(Scheduler):
                  warmup_t=0,
                  warmup_lr_init=0,
                  lr_min=0,
-                 mode='min',
+                 mode='max',
+                 noise_range_t=None,
+                 noise_type='normal',
+                 noise_pct=0.67,
+                 noise_std=1.0,
+                 noise_seed=None,
                  initialize=True,
                  ):
         super().__init__(optimizer, 'lr', initialize=initialize)
@@ -32,6 +37,11 @@ class PlateauLRScheduler(Scheduler):
             min_lr=lr_min
         )
 
+        self.noise_range = noise_range_t
+        self.noise_pct = noise_pct
+        self.noise_type = noise_type
+        self.noise_std = noise_std
+        self.noise_seed = noise_seed if noise_seed is not None else 42
         self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
         if self.warmup_t:
@@ -39,6 +49,7 @@ class PlateauLRScheduler(Scheduler):
             super().update_groups(self.warmup_lr_init)
         else:
             self.warmup_steps = [1 for _ in self.base_values]
+        self.restore_lr = None
 
     def state_dict(self):
         return {
@@ -57,4 +68,40 @@ class PlateauLRScheduler(Scheduler):
             lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
             super().update_groups(lrs)
         else:
-            self.lr_scheduler.step(metric, epoch)
+            if self.restore_lr is not None:
+                # restore actual LR from before our last noise perturbation before stepping base
+                for i, param_group in enumerate(self.optimizer.param_groups):
+                    param_group['lr'] = self.restore_lr[i]
+                self.restore_lr = None
+
+            self.lr_scheduler.step(metric, epoch)  # step the base scheduler
+
+            if self.noise_range is not None:
+                if isinstance(self.noise_range, (list, tuple)):
+                    apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
+                else:
+                    apply_noise = epoch >= self.noise_range
+                if apply_noise:
+                    self._apply_noise(epoch)
+
+    def _apply_noise(self, epoch):
+        g = torch.Generator()
+        g.manual_seed(self.noise_seed + epoch)
+        if self.noise_type == 'normal':
+            while True:
+                # resample if noise out of percent limit, brute force but shouldn't spin much
+                noise = torch.randn(1, generator=g).item()
+                if abs(noise) < self.noise_pct:
+                    break
+        else:
+            noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
+
+        # apply the noise on top of previous LR, cache the old value so we can restore for normal
+        # stepping of base scheduler
+        restore_lr = []
+        for i, param_group in enumerate(self.optimizer.param_groups):
+            old_lr = float(param_group['lr'])
+            restore_lr.append(old_lr)
+            new_lr = old_lr + old_lr * noise
+            param_group['lr'] = new_lr
+        self.restore_lr = restore_lr
diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py
index ee4220ec..b058e3d2 100644
--- a/timm/scheduler/scheduler_factory.py
+++ b/timm/scheduler/scheduler_factory.py
@@ -7,49 +7,49 @@ from .plateau_lr import PlateauLRScheduler
 def create_scheduler(args, optimizer):
     num_epochs = args.epochs
 
-    if args.lr_noise is not None:
-        if isinstance(args.lr_noise, (list, tuple)):
-            noise_range = [n * num_epochs for n in args.lr_noise]
+    if getattr(args, 'lr_noise', None) is not None:
+        lr_noise = getattr(args, 'lr_noise')
+        if isinstance(lr_noise, (list, tuple)):
+            noise_range = [n * num_epochs for n in lr_noise]
             if len(noise_range) == 1:
                 noise_range = noise_range[0]
         else:
-            noise_range = args.lr_noise * num_epochs
+            noise_range = lr_noise * num_epochs
     else:
         noise_range = None
 
     lr_scheduler = None
-    #FIXME expose cycle parms of the scheduler config to arguments
     if args.sched == 'cosine':
         lr_scheduler = CosineLRScheduler(
             optimizer,
             t_initial=num_epochs,
-            t_mul=args.lr_cycle_mul,
+            t_mul=getattr(args, 'lr_cycle_mul', 1.),
             lr_min=args.min_lr,
             decay_rate=args.decay_rate,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
-            cycle_limit=args.lr_cycle_limit,
+            cycle_limit=getattr(args, 'lr_cycle_limit', 0),
             t_in_epochs=True,
             noise_range_t=noise_range,
-            noise_pct=args.lr_noise_pct,
-            noise_std=args.lr_noise_std,
-            noise_seed=args.seed,
+            noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+            noise_std=getattr(args, 'lr_noise_std', 1.),
+            noise_seed=getattr(args, 'seed', 42),
         )
         num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
     elif args.sched == 'tanh':
         lr_scheduler = TanhLRScheduler(
             optimizer,
             t_initial=num_epochs,
-            t_mul=args.lr_cycle_mul,
+            t_mul=getattr(args, 'lr_cycle_mul', 1.),
             lr_min=args.min_lr,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
-            cycle_limit=args.lr_cycle_limit,
+            cycle_limit=getattr(args, 'lr_cycle_limit', 0),
             t_in_epochs=True,
             noise_range_t=noise_range,
-            noise_pct=args.lr_noise_pct,
-            noise_std=args.lr_noise_std,
-            noise_seed=args.seed,
+            noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+            noise_std=getattr(args, 'lr_noise_std', 1.),
+            noise_seed=getattr(args, 'seed', 42),
         )
         num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
     elif args.sched == 'step':
@@ -60,19 +60,25 @@ def create_scheduler(args, optimizer):
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
             noise_range_t=noise_range,
-            noise_pct=args.lr_noise_pct,
-            noise_std=args.lr_noise_std,
-            noise_seed=args.seed,
+            noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+            noise_std=getattr(args, 'lr_noise_std', 1.),
+            noise_seed=getattr(args, 'seed', 42),
         )
     elif args.sched == 'plateau':
+        mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
         lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=args.min_lr,
+           mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
-           cooldown_t=args.cooldown_epochs,
+           cooldown_t=0,
+           noise_range_t=noise_range,
+           noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+           noise_std=getattr(args, 'lr_noise_std', 1.),
+           noise_seed=getattr(args, 'seed', 42),
         )
 
     return lr_scheduler, num_epochs
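
For reference, a minimal usage sketch of the updated factory (not part of the diff): it assumes an argparse-style namespace whose attribute names follow train.py's arguments, with the `lr_noise*` / `seed` fields optional thanks to the `getattr` defaults above, and `eval_metric` deciding whether the plateau scheduler tracks a 'min' or 'max' quantity.

```python
# Hypothetical wiring of the plateau scheduler with LR noise through create_scheduler.
from types import SimpleNamespace

import torch
from timm.scheduler import create_scheduler

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

args = SimpleNamespace(
    epochs=100,
    sched='plateau',
    eval_metric='top1',      # no 'loss' in the name -> mode='max'
    decay_rate=0.1,
    patience_epochs=10,
    min_lr=1e-5,
    warmup_lr=1e-4,
    warmup_epochs=3,
    lr_noise=[0.6, 0.9],     # noise active for epochs in [0.6, 0.9) * epochs
    lr_noise_pct=0.67,
    lr_noise_std=1.0,
    seed=42,
)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)

# Per epoch, step with the eval metric, roughly as train.py does:
#   lr_scheduler.step(epoch + 1, eval_metrics[args.eval_metric])
```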
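The perturbation added to `PlateauLRScheduler` is a bounded multiplicative jitter: seeded by `noise_seed + epoch` so it is reproducible, kept within roughly +/- `noise_pct` of the current LR, and undone via `restore_lr` before the wrapped `ReduceLROnPlateau` takes its next step. A standalone sketch of just that math (mirroring `_apply_noise` above, not timm API):

```python
import torch

def noised_lr(lr, epoch, noise_pct=0.67, noise_seed=42, noise_type='normal'):
    # Draw a bounded noise factor and scale the LR by (1 + noise), as _apply_noise does.
    g = torch.Generator()
    g.manual_seed(noise_seed + epoch)
    if noise_type == 'normal':
        while True:
            # resample until the draw falls inside the +/- noise_pct limit
            noise = torch.randn(1, generator=g).item()
            if abs(noise) < noise_pct:
                break
    else:
        # uniform draw in [-noise_pct, +noise_pct]
        noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * noise_pct
    return lr + lr * noise

print(noised_lr(0.1, epoch=5))   # some value in (0.033, 0.167) for noise_pct=0.67
```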
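And for the model additions in the README hunk (Inception-V3, RegNet, VovNet V2, plus the new result rows), models are created by the all-lowercase fn names shown in the tables; a minimal sketch, assuming timm from this repo is installed and the listed `ese_vovnet39b` weights are published:

```python
import torch
import timm

# 'ese_vovnet39b' is one of the new rows in the results table above (79.32 top-1 @ 224).
model = timm.create_model('ese_vovnet39b', pretrained=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000]) -- ImageNet-1k classes
```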