Update README with model results and attribution. Make scheduler factory a bit more robust to arg differences, add noise to plateau LR scheduler and fix min/max mode.

pull/179/head
Ross Wightman 4 years ago
parent d1b5dddad1
commit f225ae8e59

@ -64,28 +64,6 @@ Bunch of changes:
### Feb 12, 2020
* Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
### Feb 6, 2020
* Add RandAugment trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)
### Feb 1/2, 2020
* Port new EfficientNet-B8 (RandAugment) weights, these are different than the B8 AdvProp, different input normalization.
* Update results csv files on all models for ImageNet validation and three other test sets
* Push PyPi package update
### Jan 31, 2020
* Update ResNet50 weights with a new 79.038 result from further JSD / AugMix experiments. Full command line for reproduction in training section below.
### Jan 11/12, 2020
* Master may be a bit unstable wrt training, these changes have been tested but not all combos
* Implementations of AugMix added to existing RA and AA. Including numerous supporting pieces like JSD loss (Jensen-Shannon divergence + CE), and AugMixDataset
* SplitBatchNorm adaptation layer added for implementing Auxiliary BN as per AdvProp paper
* ResNet-50 AugMix trained model w/ 79% top-1 added
* `seresnext26tn_32x4d` - 77.99 top-1, 93.75 top-5 added to tiered experiment, higher img/s than 't' and 'd'
### Jan 3, 2020
* Add RandAugment trained EfficientNet-B0 weight with 77.7 top-1. Trained by [Michael Klachko](https://github.com/michaelklachko) with this code and recent hparams (see Training section)
* Add `avg_checkpoints.py` script for post training weight averaging and update all scripts with header docstrings and shebangs.
## Introduction
For each competition, personal, or freelance project involving images + Convolutional Neural Networks, I build on top of an evolving collection of code and models. This repo contains a (somewhat) cleaned up and pared down iteration of that code. Hopefully it'll be of use to others.
@ -119,6 +97,7 @@ Included models:
* DenseNet-121, DenseNet-169, DenseNet-201, DenseNet-161
* Squeeze-and-Excitation ResNet/ResNeXt (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch) with some pretrained weight additions by myself)
* SENet-154, SE-ResNet-18, SE-ResNet-34, SE-ResNet-50, SE-ResNet-101, SE-ResNet-152, SE-ResNeXt-26 (32x4d), SE-ResNeXt50 (32x4d), SE-ResNeXt101 (32x4d)
* Inception-V3 (from [torchvision](https://github.com/pytorch/vision/tree/master/torchvision/models))
* Inception-ResNet-V2 and Inception-V4 (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch) )
* Xception
* Original variant from [Cadene](https://github.com/Cadene/pretrained-models.pytorch)
@ -143,6 +122,12 @@ Included models:
* code from https://github.com/mehtadushy/SelecSLS-Pytorch, paper https://arxiv.org/abs/1907.00837
* TResNet
* code from https://github.com/mrT23/TResNet, paper https://arxiv.org/abs/2003.13630
* RegNet
* paper `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678
* reference code at https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
* VovNet V2 (with V1 support)
* paper `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
* reference code at https://github.com/youngwanLEE/vovnet-detectron2
Use the `--model` arg to specify the model for the train, validation, and inference scripts. Pass the
all-lowercase name of the creation fn for the model you'd like.
@ -187,6 +172,8 @@ I've leveraged the training scripts in this repository to train a few of the mod
| skresnext50d_32x4d | 80.156 (19.844) | 94.642 (5.358) | 27.5M | bicubic | 224 |
| resnext50_32x4d | 79.762 (20.238) | 94.600 (5.400) | 25M | bicubic | 224 |
| resnext50d_32x4d | 79.674 (20.326) | 94.868 (5.132) | 25.1M | bicubic | 224 |
| ese_vovnet39b | 79.320 (20.680) | 94.710 (5.290) | 24.6M | bicubic | 224 |
| resnetblur50 | 79.290 (20.710) | 94.632 (5.368) | 25.6M | bicubic | 224 |
| resnet50 | 79.038 (20.962) | 94.390 (5.610) | 25.6M | bicubic | 224 |
| mixnet_l | 78.976 (21.024) | 94.184 (5.816) | 7.33M | bicubic | 224 |
| efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.79M | bicubic | 240 |
@ -200,6 +187,7 @@ I've leveraged the training scripts in this repository to train a few of the mod
| seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic | 224 |
| skresnet34 | 76.912 (23.088) | 93.322 (6.678) | 22.2M | bicubic | 224 |
| resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic | 224 |
| densenetblur121d | 76.576 (23.424) | 93.190 (6.810) | 8.0M | bicubic | 224 |
| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1M | bicubic | 224 |
| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13M | bicubic | 224 |
| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5M | bicubic | 224 |

@ -396,6 +396,24 @@ model_list = [
_entry('selecsls60b', 'SelecSLS-60_B', '1907.00837',
model_desc='Originally from https://github.com/mehtadushy/SelecSLS-Pytorch'),
## ResNeSt official impl weights
_entry('resnest14d', 'ResNeSt-14', '2004.08955',
model_desc='Originally from GluonCV'),
_entry('resnest26d', 'ResNeSt-26', '2004.08955',
model_desc='Originally from GluonCV'),
_entry('resnest50d', 'ResNeSt-50', '2004.08955',
model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
_entry('resnest101e', 'ResNeSt-101', '2004.08955',
model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
_entry('resnest200e', 'ResNeSt-200', '2004.08955',
model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
_entry('resnest269e', 'ResNeSt-269', '2004.08955', batch_size=BATCH_SIZE // 2,
model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
_entry('resnest50d_4s2x40d', 'ResNeSt-50 4s2x40d', '2004.08955',
model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
_entry('resnest50d_1s4x24d', 'ResNeSt-50 1s4x24d', '2004.08955',
model_desc='Originally from https://github.com/zhanghang1989/ResNeSt'),
## RegNet official impl weights
_entry('regnetx_002', 'RegNetX-200MF', '2003.13678'),
_entry('regnetx_004', 'RegNetX-400MF', '2003.13678'),

@ -16,7 +16,12 @@ class PlateauLRScheduler(Scheduler):
warmup_t=0,
warmup_lr_init=0,
lr_min=0,
mode='min',
mode='max',
noise_range_t=None,
noise_type='normal',
noise_pct=0.67,
noise_std=1.0,
noise_seed=None,
initialize=True,
):
super().__init__(optimizer, 'lr', initialize=initialize)
@ -32,6 +37,11 @@ class PlateauLRScheduler(Scheduler):
min_lr=lr_min
)
self.noise_range = noise_range_t
self.noise_pct = noise_pct
self.noise_type = noise_type
self.noise_std = noise_std
self.noise_seed = noise_seed if noise_seed is not None else 42
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
if self.warmup_t:
@ -39,6 +49,7 @@ class PlateauLRScheduler(Scheduler):
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
self.restore_lr = None
def state_dict(self):
return {
@ -57,4 +68,40 @@ class PlateauLRScheduler(Scheduler):
lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
super().update_groups(lrs)
else:
self.lr_scheduler.step(metric, epoch)
if self.restore_lr is not None:
# restore actual LR from before our last noise perturbation before stepping base
for i, param_group in enumerate(self.optimizer.param_groups):
param_group['lr'] = self.restore_lr[i]
self.restore_lr = None
self.lr_scheduler.step(metric, epoch) # step the base scheduler
if self.noise_range is not None:
if isinstance(self.noise_range, (list, tuple)):
apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
else:
apply_noise = epoch >= self.noise_range
if apply_noise:
self._apply_noise(epoch)
def _apply_noise(self, epoch):
g = torch.Generator()
g.manual_seed(self.noise_seed + epoch)
if self.noise_type == 'normal':
while True:
# resample if noise out of percent limit, brute force but shouldn't spin much
noise = torch.randn(1, generator=g).item()
if abs(noise) < self.noise_pct:
break
else:
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
# apply the noise on top of previous LR, cache the old value so we can restore for normal
# stepping of base scheduler
restore_lr = []
for i, param_group in enumerate(self.optimizer.param_groups):
old_lr = float(param_group['lr'])
restore_lr.append(old_lr)
new_lr = old_lr + old_lr * noise
param_group['lr'] = new_lr
self.restore_lr = restore_lr
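
The noise logic in this hunk can be sketched without PyTorch. What follows is a minimal stand-alone illustration, not the code in the commit: it swaps `torch.Generator` for the standard-library `random.Random`, uses plain dicts in place of optimizer param groups, and the function names `noise_active`, `sample_noise`, `apply_noise`, and `restore` are hypothetical. Like `_apply_noise` as shown above, it never uses `noise_std`.

```python
import random


def noise_active(epoch, noise_range):
    """Mirror the range check: a pair is a half-open window, a scalar is a start epoch."""
    if noise_range is None:
        return False
    if isinstance(noise_range, (list, tuple)):
        return noise_range[0] <= epoch < noise_range[1]
    return epoch >= noise_range


def sample_noise(epoch, seed=42, noise_pct=0.67, noise_type="normal"):
    """Return a relative LR perturbation bounded by noise_pct."""
    # Assumption: random.Random stands in for torch.Generator here.
    # Seeding with seed + epoch makes the noise reproducible per epoch.
    rng = random.Random(seed + epoch)
    if noise_type == "normal":
        while True:
            # Rejection sampling: redraw until the sample is inside the limit.
            noise = rng.gauss(0.0, 1.0)
            if abs(noise) < noise_pct:
                return noise
    # Uniform case: map [0, 1) onto [-noise_pct, noise_pct).
    return 2 * (rng.random() - 0.5) * noise_pct


def apply_noise(param_groups, noise):
    """Perturb each group's LR in place and return the unperturbed values."""
    restore_lr = []
    for group in param_groups:
        old_lr = float(group["lr"])
        restore_lr.append(old_lr)
        group["lr"] = old_lr + old_lr * noise
    return restore_lr


def restore(param_groups, restore_lr):
    """Put the cached LRs back before the base scheduler steps again."""
    for group, lr in zip(param_groups, restore_lr):
        group["lr"] = lr


groups = [{"lr": 0.1}, {"lr": 0.01}]  # stand-in for optimizer.param_groups
noise = sample_noise(epoch=5)
saved = apply_noise(groups, noise)
noisy = [g["lr"] for g in groups]
restore(groups, saved)
```

Restoring before stepping matters if the base scheduler derives the next LR from the value currently stored in each param group (PyTorch's `ReduceLROnPlateau` does this). Without the restore, a decay would be applied to the perturbed value and the noise would compound across epochs.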

@ -7,49 +7,49 @@ from .plateau_lr import PlateauLRScheduler
def create_scheduler(args, optimizer):
num_epochs = args.epochs
if args.lr_noise is not None:
if isinstance(args.lr_noise, (list, tuple)):
noise_range = [n * num_epochs for n in args.lr_noise]
if getattr(args, 'lr_noise', None) is not None:
lr_noise = getattr(args, 'lr_noise')
if isinstance(lr_noise, (list, tuple)):
noise_range = [n * num_epochs for n in lr_noise]
if len(noise_range) == 1:
noise_range = noise_range[0]
else:
noise_range = args.lr_noise * num_epochs
noise_range = lr_noise * num_epochs
else:
noise_range = None
lr_scheduler = None
#FIXME expose cycle parms of the scheduler config to arguments
if args.sched == 'cosine':
lr_scheduler = CosineLRScheduler(
optimizer,
t_initial=num_epochs,
t_mul=args.lr_cycle_mul,
t_mul=getattr(args, 'lr_cycle_mul', 1.),
lr_min=args.min_lr,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cycle_limit=args.lr_cycle_limit,
cycle_limit=getattr(args, 'lr_cycle_limit', 0),
t_in_epochs=True,
noise_range_t=noise_range,
noise_pct=args.lr_noise_pct,
noise_std=args.lr_noise_std,
noise_seed=args.seed,
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
noise_std=getattr(args, 'lr_noise_std', 1.),
noise_seed=getattr(args, 'seed', 42),
)
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
elif args.sched == 'tanh':
lr_scheduler = TanhLRScheduler(
optimizer,
t_initial=num_epochs,
t_mul=args.lr_cycle_mul,
t_mul=getattr(args, 'lr_cycle_mul', 1.),
lr_min=args.min_lr,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cycle_limit=args.lr_cycle_limit,
cycle_limit=getattr(args, 'lr_cycle_limit', 0),
t_in_epochs=True,
noise_range_t=noise_range,
noise_pct=args.lr_noise_pct,
noise_std=args.lr_noise_std,
noise_seed=args.seed,
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
noise_std=getattr(args, 'lr_noise_std', 1.),
noise_seed=getattr(args, 'seed', 42),
)
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
elif args.sched == 'step':
@ -60,19 +60,25 @@ def create_scheduler(args, optimizer):
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
noise_range_t=noise_range,
noise_pct=args.lr_noise_pct,
noise_std=args.lr_noise_std,
noise_seed=args.seed,
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
noise_std=getattr(args, 'lr_noise_std', 1.),
noise_seed=getattr(args, 'seed', 42),
)
elif args.sched == 'plateau':
mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
lr_scheduler = PlateauLRScheduler(
optimizer,
decay_rate=args.decay_rate,
patience_t=args.patience_epochs,
lr_min=args.min_lr,
mode=mode,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cooldown_t=args.cooldown_epochs,
cooldown_t=0,
noise_range_t=noise_range,
noise_pct=getattr(args, 'lr_noise_pct', 0.67),
noise_std=getattr(args, 'lr_noise_std', 1.),
noise_seed=getattr(args, 'seed', 42),
)
return lr_scheduler, num_epochs
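
The argument handling in this factory can be exercised in isolation. The sketch below is illustrative only: `resolve_noise_range` and `resolve_plateau_mode` are hypothetical helper names that do not exist in the commit, and `SimpleNamespace` stands in for the parsed `args` object. It shows how `getattr` with a default lets a config that lacks the optional attributes pass through.

```python
from types import SimpleNamespace


def resolve_noise_range(args):
    """Convert lr_noise (fractions of training) into epoch values, or None."""
    lr_noise = getattr(args, "lr_noise", None)
    if lr_noise is None:
        return None
    if isinstance(lr_noise, (list, tuple)):
        noise_range = [n * args.epochs for n in lr_noise]
        # A single-element list collapses to a scalar start epoch.
        return noise_range[0] if len(noise_range) == 1 else noise_range
    return lr_noise * args.epochs


def resolve_plateau_mode(args):
    """Lower is better for loss metrics; anything else is maximized."""
    return "min" if "loss" in getattr(args, "eval_metric", "") else "max"


# Hypothetical configs: one with no optional attributes, one fully specified.
bare = SimpleNamespace(epochs=100)
full = SimpleNamespace(epochs=100, lr_noise=[0.5, 0.75], eval_metric="loss")

bare_range = resolve_noise_range(bare)   # no lr_noise attribute at all
full_range = resolve_noise_range(full)   # window expressed in epochs
bare_mode = resolve_plateau_mode(bare)
full_mode = resolve_plateau_mode(full)
```

With the direct `args.lr_noise` access used before this change, the `bare` config would raise `AttributeError`; with `getattr` it simply disables noise and the plateau mode falls back to `'max'`.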
