Add MobileNetV3 Large weights, results, update README and sotabench for merge

pull/94/head
Ross Wightman 5 years ago
parent 9fee316752
commit c16f25ced2

@ -2,6 +2,14 @@
## What's New
### Feb 29, 2020
* New MobileNet-V3 Large weights trained from stratch with this code to 75.77% top-1
* IMPORTANT CHANGE - default weight init changed for all MobilenetV3 / EfficientNet / related models
* overall results similar to a bit better training from scratch on a few smaller models tried
* performance early in training seems consistently improved but less difference by end
* set `fix_group_fanout=False` in `_init_weight_goog` fn if you need to reproducte past behaviour
* Experimental LR noise feature added applies a random perturbation to LR each epoch in specified range of training
### Feb 18, 2020
* Big refactor of model layers and addition of several attention mechanisms. Several additions motivated by 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268):
* Move layer/module impl into `layers` subfolder/module of `models` and organize in a more granular fashion
@ -187,7 +195,8 @@ I've leveraged the training scripts in this repository to train a few of the mod
| skresnet34 | 76.912 (23.088) | 93.322 (6.678) | 22.2M | bicubic | 224 |
| resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic | 224 |
| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13M | bicubic | 224 |
| mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 |
| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5M | bicubic | 224 |
| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 |
| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.89M | bicubic | 224 |
| resnet26 | 75.292 (24.708) | 92.57 (7.43) | 16M | bicubic | 224 |
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6M | bilinear | 224 |
@ -361,6 +370,11 @@ Trained by [Andrew Lavin](https://github.com/andravin) with 8 V100 cards. Model
`./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064`
### MobileNetV3-Large-100 - 75.766 top-1, 92,542 top-5
`./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9`
**TODO dig up some more**

@ -93,7 +93,7 @@ model_list = [
_entry('semnasnet_100', 'MnasNet-A1', '1807.11626'),
_entry('spnasnet_100', 'Single-Path NAS', '1904.02877',
model_desc='Trained in PyTorch with SGD, cosine LR decay'),
_entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244',
_entry('mobilenetv3_large_100', 'MobileNet V3-Large 1.0', '1905.02244',
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
'paper as closely as possible.'),

@ -31,7 +31,9 @@ def _cfg(url='', **kwargs):
default_cfgs = {
'mobilenetv3_large_075': _cfg(url=''),
'mobilenetv3_large_100': _cfg(url=''),
'mobilenetv3_large_100': _cfg(
interpolation='bicubic',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
'mobilenetv3_small_075': _cfg(url=''),
'mobilenetv3_small_100': _cfg(url=''),
'mobilenetv3_rw': _cfg(

@ -10,9 +10,10 @@ def create_scheduler(args, optimizer):
if args.lr_noise is not None:
if isinstance(args.lr_noise, (list, tuple)):
noise_range = [n * num_epochs for n in args.lr_noise]
if len(noise_range) == 1:
noise_range = noise_range[0]
else:
noise_range = args.lr_noise * num_epochs
print('Noise range:', noise_range)
else:
noise_range = None

Loading…
Cancel
Save