Add instagram pretrained ResNeXt models from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/, update README

pull/16/head
Ross Wightman 5 years ago
parent 1202d053bc
commit 8512436436

@ -20,6 +20,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec
* ResNet/ResNeXt (from [torchvision](https://github.com/pytorch/vision/tree/master/torchvision/models) with ResNeXt mods by myself)
* ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152, ResNeXt50 (32x4d), ResNeXt101 (32x4d and 64x4d)
* Instagram trained / ImageNet tuned ResNeXt101-32x8d to 32x48d from from [facebookresearch](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/)
* DenseNet (from [torchvision](https://github.com/pytorch/vision/tree/master/torchvision/models))
* DenseNet-121, DenseNet-169, DenseNet-201, DenseNet-161
* Squeeze-and-Excitation ResNet/ResNeXt (from [Cadene](https://github.com/Cadene/pretrained-models.pytorch) with some pretrained weight additions by myself)
@ -141,16 +142,13 @@ I've leveraged the training scripts in this repository to train a few of the mod
| tf_efficientnet_b5 *tfp | 83.200 (16.800) | 96.456 (3.544) | 30.39 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
| tf_efficientnet_b5 | 83.176 (16.824) | 96.536 (3.464) | 30.39 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
NOTE: For some reason I can't hit the stated accuracy with my impl of MNASNet and Google's tflite weights. Using a TF equivalent to 'SAME' padding was important to get > 70%, but something small is still missing. Trying to train my own weights from scratch with these models has so far to leveled off in the same 72-73% range.
Models with `*tfp` next to them were scored with `--tf-preprocessing` flag.
The `tf_efficientnet` and `tflite_(se)mnasnet` models require an equivalent for 'SAME' padding as their arch results in asymmetric padding. I've added this in the model creation wrapper, but it does come with a performance penalty.
## Script Usage
## Usage
## Environment
### Environment
All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x and 3.7.x. Little to no care has been taken to be Python 2.x friendly and I don't plan to support it. If you run into any challenges running on Windows, or other OS, I'm definitely open to looking into those issues so long as it's in a reproducible (read Conda) environment.
@ -163,7 +161,25 @@ conda activate torch-env
conda install -c pytorch pytorch torchvision cudatoolkit=10.0
```
### Training
### Pip
This package can be installed via pip. Currently, the model factory (`timm.create_model`) is the most useful component to use via a pip install.
Install (after conda env/install):
```
pip install timm
```
Use:
```
>>> import timm
>>> m = timm.create_model('mobilenetv3_100', pretrained=True)
>>> m.eval()
```
### Scripts
A train, validation, inference, and checkpoint cleaning script included in the github root folder. Scripts are not currently packaged in the pip release.
#### Training
The variety of training args is large and not all combinations of options (or even options) have been fully tested. For the training dataset folder, specify the folder to the base that contains a `train` and `validation` folder.
@ -173,7 +189,7 @@ To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process pe
NOTE: NVIDIA APEX should be installed to run in per-process distributed via DDP or to enable AMP mixed precision with the --amp flag
### Validation / Inference
#### Validation / Inference
Validation and inference scripts are similar in usage. One outputs metrics on a validation set and the other outputs topk class ids in a csv. Specify the folder containing validation images, not the base as in training script.

@ -12,7 +12,8 @@ from .adaptive_avgmax_pool import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
_models = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d']
'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d',
'ig_resnext101_32x8d', 'ig_resnext101_32x16d', 'ig_resnext101_32x32d', 'ig_resnext101_32x48d']
__all__ = ['ResNet'] + _models
@ -39,6 +40,10 @@ default_cfgs = {
'resnext101_32x4d': _cfg(url=''),
'resnext101_64x4d': _cfg(url=''),
'resnext152_32x4d': _cfg(url=''),
'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'),
'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'),
'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'),
'ig_resnext101_32x48d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth'),
}
@ -324,3 +329,75 @@ def resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def ig_resnext101_32x8d(pretrained=True, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data
and finetuned on ImageNet from Figure 5 in
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
Args:
pretrained (bool): load pretrained weights
num_classes (int): number of classes for classifier (default: 1000 for pretrained)
in_chans (int): number of input planes (default: 3 for pretrained / color)
"""
default_cfg = default_cfgs['ig_resnext101_32x8d']
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def ig_resnext101_32x16d(pretrained=True, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data
and finetuned on ImageNet from Figure 5 in
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
Args:
pretrained (bool): load pretrained weights
num_classes (int): number of classes for classifier (default: 1000 for pretrained)
in_chans (int): number of input planes (default: 3 for pretrained / color)
"""
default_cfg = default_cfgs['ig_resnext101_32x16d']
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def ig_resnext101_32x32d(pretrained=True, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data
and finetuned on ImageNet from Figure 5 in
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
Args:
pretrained (bool): load pretrained weights
num_classes (int): number of classes for classifier (default: 1000 for pretrained)
in_chans (int): number of input planes (default: 3 for pretrained / color)
"""
default_cfg = default_cfgs['ig_resnext101_32x32d']
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=32, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
def ig_resnext101_32x48d(pretrained=True, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data
and finetuned on ImageNet from Figure 5 in
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
Args:
pretrained (bool): load pretrained weights
num_classes (int): number of classes for classifier (default: 1000 for pretrained)
in_chans (int): number of input planes (default: 3 for pretrained / color)
"""
default_cfg = default_cfgs['ig_resnext101_32x48d']
model = ResNet(Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=48, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

Loading…
Cancel
Save