Add some info to README

pull/2/head
Ross Wightman 5 years ago
parent db8ad25a23
commit e3377b0409

@ -36,6 +36,62 @@ I've included a few of my favourite models, but this is not an exhaustive collec
* MobileNet-V3 (work in progress, validating config)
* ChamNet (details hard to find, currently an educated guess)
* FBNet-C (TODO A/B variants)
The full list of model strings that can be passed to model factory via `--model` arg for train, validation, inference scripts:
```
chamnetv1_100
chamnetv2_100
densenet121
densenet161
densenet169
densenet201
dpn107
dpn131
dpn68
dpn68b
dpn92
dpn98
fbnetc_100
inception_resnet_v2
inception_v4
mnasnet_050
mnasnet_075
mnasnet_100
mnasnet_140
mnasnet_small
mobilenetv1_100
mobilenetv2_100
mobilenetv3_050
mobilenetv3_075
mobilenetv3_100
pnasnet5large
resnet101
resnet152
resnet18
resnet34
resnet50
resnext101_32x4d
resnext101_64x4d
resnext152_32x4d
resnext50_32x4d
semnasnet_050
semnasnet_075
semnasnet_100
semnasnet_140
seresnet101
seresnet152
seresnet18
seresnet34
seresnet50
seresnext101_32x4d
seresnext26_32x4d
seresnext50_32x4d
spnasnet_100
tflite_mnasnet_100
tflite_semnasnet_100
xception
```
## Features
Several (less common) features that I often utilize in my projects are included. Many of their additions are the reason why I maintain my own set of models, instead of using others' via PIP:
* All models have a common default configuration interface and API for
@ -72,6 +128,30 @@ I've leveraged the training scripts in this repository to train a few of the mod
NOTE: For some reason I can't hit the stated accuracy with my impl of MNASNet and Google's tflite weights. Using a TF equivalent to 'SAME' padding was important to get > 70%, but something small is still missing. Trying to train my own weights from scratch with these models has so far to leveled off in the same 72-73% range.
## Script Usage
### Training
The variety of training args is large and not all combinations of options (or even options) have been fully tested. For the training dataset folder, specify the folder to the base that contains a `train` and `validation` folder.
To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process per GPU w/ cosine schedule, random-erasing prob of 50% and per-pixel random value:
`./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --repp --batch-size 256 -j 4`
NOTE: NVIDIA APEX should be installed to run in per-process distributed via DDP or to enable AMP mixed precision with the --amp flag
### Validation / Inference
Validation and inference scripts are similar in usage. One outputs metrics on a validation set and the other outputs topk class ids in a csv. Specify the folder containing validation images, not the base as in training script.
To validate with the model's pretrained weights (if they exist):
`python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained`
To run inference from a checkpoint:
`python inference.py /imagenet/validation/ --model mobilenetv3_100 --checkpoint ./output/model_best.pth.tar`
## TODO
A number of additions planned in the future for various projects, incl
* Find optimal training hyperparams and create/port pretraiend weights for the generic MobileNet variants

Loading…
Cancel
Save