Merge pull request #19 from rwightman/resnet-refactor

Resnet refactoring
pull/23/head
Ross Wightman 6 years ago committed by GitHub
commit c11973602d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,8 +18,9 @@ The work of many others is present here. I've tried to make sure all source mate
I've included a few of my favourite models, but this is not an exhaustive collection. You can't do better than Cadene's collection in that regard. Most models do have pretrained weights from their respective sources or original authors.
* ResNet/ResNeXt (from [torchvision](https://github.com/pytorch/vision/tree/master/torchvision/models) with ResNeXt mods by myself)
* ResNet/ResNeXt (from [torchvision](https://github.com/pytorch/vision/tree/master/torchvision/models) with mods by myself)
* ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152, ResNeXt50 (32x4d), ResNeXt101 (32x4d and 64x4d)
* 'Bag of Tricks' / Gluon C, D, E, S variations (https://arxiv.org/abs/1812.01187)
* Instagram trained / ImageNet tuned ResNeXt101-32x8d to 32x48d from from [facebookresearch](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/)
* DenseNet (from [torchvision](https://github.com/pytorch/vision/tree/master/torchvision/models))
* DenseNet-121, DenseNet-169, DenseNet-201, DenseNet-161
@ -70,12 +71,15 @@ I've leveraged the training scripts in this repository to train a few of the mod
#### @ 224x224
|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling |
|---|---|---|---|---|
| resnext50d_32x4d | 79.674 (20.326) | 94.868 (5.132) | 25.1M | bicubic |
| resnext50_32x4d | 78.512 (21.488) | 94.042 (5.958) | 25M | bicubic |
| resnet50 | 78.470 (21.530) | 94.266 (5.734) | 25.6M | bicubic |
| seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic |
| efficientnet_b0 | 76.912 (23.088) | 93.210 (6.790) | 5.29M | bicubic |
| resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic |
| mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic |
| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.89M | bicubic |
| resnet26 | 75.292 (24.708) | 92.57 (7.43) | 16M | bicubic |
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6M | bilinear |
| resnet34 | 75.110 (24.890) | 92.284 (7.716) | 22M | bilinear |
| seresnet34 | 74.808 (25.192) | 92.124 (7.876) | 22M | bilinear |
@ -120,8 +124,6 @@ I've leveraged the training scripts in this repository to train a few of the mod
| tf_efficientnet_b0 *tfp | 76.828 (23.172) | 93.226 (6.774) | 5.29 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
| tf_efficientnet_b0 | 76.528 (23.472) | 93.010 (6.990) | 5.29 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) |
| gluon_resnet34_v1b | 74.580 (25.420) | 91.988 (8.012) | 21.80 | bicubic | |
| tflite_semnasnet_100 | 73.086 (26.914) | 91.336 (8.664) | 3.87 | bicubic | [Google TFLite](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet) |
| tflite_mnasnet_100 | 72.398 (27.602) | 90.930 (9.070) | 4.36 | bicubic | [Google TFLite](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet)
| gluon_resnet18_v1b | 70.830 (29.170) | 89.756 (10.244) | 11.69 | bicubic | |
#### @ 240x240

@ -2,8 +2,6 @@ model,top1,top1_err,top5,top5_err,param_count,img_size,cropt_pct,interpolation
resnet18,69.758,30.242,89.078,10.922,11.69,224,0.875,bilinear
gluon_resnet18_v1b,70.83,29.17,89.756,10.244,11.69,224,0.875,bicubic
seresnet18,71.758,28.242,90.334,9.666,11.78,224,0.875,bicubic
tflite_mnasnet_100,72.4,27.6,90.936,9.064,4.36,224,0.875,bicubic
tflite_semnasnet_100,73.078,26.922,91.334,8.666,3.87,224,0.875,bicubic
tv_resnet34,73.314,26.686,91.42,8.58,21.8,224,0.875,bilinear
spnasnet_100,74.08,25.92,91.832,8.168,4.42,224,0.875,bilinear
gluon_resnet34_v1b,74.58,25.42,91.988,8.012,21.8,224,0.875,bicubic
@ -12,12 +10,14 @@ densenet121,74.752,25.248,92.152,7.848,7.98,224,0.875,bicubic
seresnet34,74.808,25.192,92.126,7.874,21.96,224,0.875,bilinear
resnet34,75.112,24.888,92.288,7.712,21.8,224,0.875,bilinear
fbnetc_100,75.12,24.88,92.386,7.614,5.57,224,0.875,bilinear
resnet26,75.292,24.708,92.57,7.43,16,224,0.875,bicubic
semnasnet_100,75.456,24.544,92.592,7.408,3.89,224,0.875,bicubic
mobilenetv3_100,75.628,24.372,92.708,7.292,5.48,224,0.875,bicubic
densenet169,75.912,24.088,93.024,6.976,14.15,224,0.875,bicubic
tv_resnet50,76.13,23.87,92.862,7.138,25.56,224,0.875,bilinear
dpn68,76.306,23.694,92.97,7.03,12.61,224,0.875,bicubic
tf_efficientnet_b0,76.528,23.472,93.01,6.99,5.29,224,0.875,bicubic
resnet26d,76.68,23.32,93.166,6.834,16.01,224,0.875,bicubic
efficientnet_b0,76.914,23.086,93.206,6.794,5.29,224,0.875,bicubic
seresnext26_32x4d,77.1,22.9,93.31,6.69,16.79,224,0.875,bicubic
densenet201,77.29,22.71,93.478,6.522,20.01,224,0.875,bicubic
@ -30,7 +30,7 @@ gluon_resnet50_v1b,77.578,22.422,93.718,6.282,25.56,224,0.875,bicubic
tv_resnext50_32x4d,77.618,22.382,93.698,6.302,25.03,224,0.875,bilinear
seresnet50,77.636,22.364,93.752,6.248,28.09,224,0.875,bilinear
tf_inception_v3,77.856,22.144,93.644,6.356,23.83,299,0.875,bicubic
gluon_resnet50_v1c,78.01,21.99,93.988,6.012,25.58,224,0.875,bicubic
gluon_resnet50_v1c,78.012,21.988,93.988,6.012,25.58,224,0.875,bicubic
resnet152,78.312,21.688,94.046,5.954,60.19,224,0.875,bilinear
seresnet101,78.396,21.604,94.258,5.742,49.33,224,0.875,bilinear
wide_resnet50_2,78.468,21.532,94.086,5.914,68.88,224,0.875,bilinear
@ -51,6 +51,7 @@ gluon_resnext50_32x4d,79.356,20.644,94.424,5.576,25.03,224,0.875,bicubic
gluon_resnet101_v1c,79.544,20.456,94.586,5.414,44.57,224,0.875,bicubic
tf_efficientnet_b2,79.606,20.394,94.712,5.288,9.11,260,0.89,bicubic
dpn98,79.636,20.364,94.594,5.406,61.57,224,0.875,bicubic
resnext50d_32x4d,79.674,20.326,94.868,5.132,25.05,224,0.875,bicubic
gluon_resnet152_v1b,79.692,20.308,94.738,5.262,60.19,224,0.875,bicubic
efficientnet_b2,79.752,20.248,94.71,5.29,9.11,260,0.89,bicubic
dpn131,79.828,20.172,94.704,5.296,79.25,224,0.875,bicubic

1 model top1 top1_err top5 top5_err param_count img_size cropt_pct interpolation
2 resnet18 69.758 30.242 89.078 10.922 11.69 224 0.875 bilinear
3 gluon_resnet18_v1b 70.83 29.17 89.756 10.244 11.69 224 0.875 bicubic
4 seresnet18 71.758 28.242 90.334 9.666 11.78 224 0.875 bicubic
tflite_mnasnet_100 72.4 27.6 90.936 9.064 4.36 224 0.875 bicubic
tflite_semnasnet_100 73.078 26.922 91.334 8.666 3.87 224 0.875 bicubic
5 tv_resnet34 73.314 26.686 91.42 8.58 21.8 224 0.875 bilinear
6 spnasnet_100 74.08 25.92 91.832 8.168 4.42 224 0.875 bilinear
7 gluon_resnet34_v1b 74.58 25.42 91.988 8.012 21.8 224 0.875 bicubic
10 seresnet34 74.808 25.192 92.126 7.874 21.96 224 0.875 bilinear
11 resnet34 75.112 24.888 92.288 7.712 21.8 224 0.875 bilinear
12 fbnetc_100 75.12 24.88 92.386 7.614 5.57 224 0.875 bilinear
13 resnet26 75.292 24.708 92.57 7.43 16 224 0.875 bicubic
14 semnasnet_100 75.456 24.544 92.592 7.408 3.89 224 0.875 bicubic
15 mobilenetv3_100 75.628 24.372 92.708 7.292 5.48 224 0.875 bicubic
16 densenet169 75.912 24.088 93.024 6.976 14.15 224 0.875 bicubic
17 tv_resnet50 76.13 23.87 92.862 7.138 25.56 224 0.875 bilinear
18 dpn68 76.306 23.694 92.97 7.03 12.61 224 0.875 bicubic
19 tf_efficientnet_b0 76.528 23.472 93.01 6.99 5.29 224 0.875 bicubic
20 resnet26d 76.68 23.32 93.166 6.834 16.01 224 0.875 bicubic
21 efficientnet_b0 76.914 23.086 93.206 6.794 5.29 224 0.875 bicubic
22 seresnext26_32x4d 77.1 22.9 93.31 6.69 16.79 224 0.875 bicubic
23 densenet201 77.29 22.71 93.478 6.522 20.01 224 0.875 bicubic
30 tv_resnext50_32x4d 77.618 22.382 93.698 6.302 25.03 224 0.875 bilinear
31 seresnet50 77.636 22.364 93.752 6.248 28.09 224 0.875 bilinear
32 tf_inception_v3 77.856 22.144 93.644 6.356 23.83 299 0.875 bicubic
33 gluon_resnet50_v1c 78.01 78.012 21.99 21.988 93.988 6.012 25.58 224 0.875 bicubic
34 resnet152 78.312 21.688 94.046 5.954 60.19 224 0.875 bilinear
35 seresnet101 78.396 21.604 94.258 5.742 49.33 224 0.875 bilinear
36 wide_resnet50_2 78.468 21.532 94.086 5.914 68.88 224 0.875 bilinear
51 gluon_resnet101_v1c 79.544 20.456 94.586 5.414 44.57 224 0.875 bicubic
52 tf_efficientnet_b2 79.606 20.394 94.712 5.288 9.11 260 0.89 bicubic
53 dpn98 79.636 20.364 94.594 5.406 61.57 224 0.875 bicubic
54 resnext50d_32x4d 79.674 20.326 94.868 5.132 25.05 224 0.875 bicubic
55 gluon_resnet152_v1b 79.692 20.308 94.738 5.262 60.19 224 0.875 bicubic
56 efficientnet_b2 79.752 20.248 94.71 5.29 9.11 260 0.89 bicubic
57 dpn131 79.828 20.172 94.704 5.296 79.25 224 0.875 bicubic

@ -0,0 +1,84 @@
model,top1,top1_err,top5,top5_err,param_count,img_size,cropt_pct,interpolation
resnet18,57.18,42.82,80.19,19.81,11.69,224,0.875,bilinear
gluon_resnet18_v1b,58.32,41.68,80.96,19.04,11.69,224,0.875,bicubic
seresnet18,59.81,40.19,81.68,18.32,11.78,224,0.875,bicubic
tv_resnet34,61.2,38.8,82.72,17.28,21.8,224,0.875,bilinear
spnasnet_100,61.21,38.79,82.77,17.23,4.42,224,0.875,bilinear
mnasnet_100,61.91,38.09,83.71,16.29,4.38,224,0.875,bicubic
fbnetc_100,62.43,37.57,83.39,16.61,5.57,224,0.875,bilinear
gluon_resnet34_v1b,62.56,37.44,84,16,21.8,224,0.875,bicubic
resnet34,62.82,37.18,84.12,15.88,21.8,224,0.875,bilinear
seresnet34,62.89,37.11,84.22,15.78,21.96,224,0.875,bilinear
densenet121,62.94,37.06,84.26,15.74,7.98,224,0.875,bicubic
semnasnet_100,63.12,36.88,84.53,15.47,3.89,224,0.875,bicubic
mobilenetv3_100,63.23,36.77,84.52,15.48,5.48,224,0.875,bicubic
tv_resnet50,63.33,36.67,84.65,15.35,25.56,224,0.875,bilinear
resnet26,63.45,36.55,84.27,15.73,16,224,0.875,bicubic
tf_efficientnet_b0,63.53,36.47,84.88,15.12,5.29,224,0.875,bicubic
dpn68,64.22,35.78,85.18,14.82,12.61,224,0.875,bicubic
efficientnet_b0,64.58,35.42,85.89,14.11,5.29,224,0.875,bicubic
resnet26d,64.63,35.37,85.12,14.88,16.01,224,0.875,bicubic
densenet169,64.78,35.22,85.25,14.75,14.15,224,0.875,bicubic
seresnext26_32x4d,65.04,34.96,85.65,14.35,16.79,224,0.875,bicubic
densenet201,65.28,34.72,85.67,14.33,20.01,224,0.875,bicubic
dpn68b,65.6,34.4,85.94,14.06,12.61,224,0.875,bicubic
resnet101,65.68,34.32,85.98,14.02,44.55,224,0.875,bilinear
densenet161,65.85,34.15,86.46,13.54,28.68,224,0.875,bicubic
gluon_resnet50_v1b,66.04,33.96,86.27,13.73,25.56,224,0.875,bicubic
inception_v3,66.12,33.88,86.34,13.66,27.16,299,0.875,bicubic
tv_resnext50_32x4d,66.18,33.82,86.04,13.96,25.03,224,0.875,bilinear
seresnet50,66.24,33.76,86.33,13.67,28.09,224,0.875,bilinear
tf_inception_v3,66.41,33.59,86.68,13.32,23.83,299,0.875,bicubic
tf_efficientnet_b1,66.52,33.48,86.68,13.32,7.79,240,0.882,bicubic
gluon_resnet50_v1c,66.54,33.46,86.16,13.84,25.58,224,0.875,bicubic
adv_inception_v3,66.6,33.4,86.56,13.44,23.83,299,0.875,bicubic
wide_resnet50_2,66.65,33.35,86.81,13.19,68.88,224,0.875,bilinear
wide_resnet101_2,66.68,33.32,87.04,12.96,126.89,224,0.875,bilinear
resnet50,66.81,33.19,87,13,25.56,224,0.875,bicubic
resnext50_32x4d,66.88,33.12,86.36,13.64,25.03,224,0.875,bicubic
resnet152,67.02,32.98,87.57,12.43,60.19,224,0.875,bilinear
gluon_resnet50_v1s,67.1,32.9,86.86,13.14,25.68,224,0.875,bicubic
seresnet101,67.15,32.85,87.05,12.95,49.33,224,0.875,bilinear
tf_efficientnet_b2,67.4,32.6,87.58,12.42,9.11,260,0.89,bicubic
gluon_resnet101_v1b,67.45,32.55,87.23,12.77,44.55,224,0.875,bicubic
efficientnet_b1,67.55,32.45,87.29,12.71,7.79,240,0.882,bicubic
seresnet152,67.55,32.45,87.39,12.61,66.82,224,0.875,bilinear
gluon_resnet101_v1c,67.56,32.44,87.16,12.84,44.57,224,0.875,bicubic
gluon_inception_v3,67.59,32.41,87.46,12.54,23.83,299,0.875,bicubic
xception,67.67,32.33,87.57,12.43,22.86,299,0.8975,bicubic
efficientnet_b2,67.8,32.2,88.2,11.8,9.11,260,0.89,bicubic
resnext101_32x8d,67.85,32.15,87.48,12.52,88.79,224,0.875,bilinear
seresnext50_32x4d,67.87,32.13,87.62,12.38,27.56,224,0.875,bilinear
gluon_resnet50_v1d,67.91,32.09,87.12,12.88,25.58,224,0.875,bicubic
dpn92,68.01,31.99,87.59,12.41,37.67,224,0.875,bicubic
gluon_resnext50_32x4d,68.28,31.72,87.32,12.68,25.03,224,0.875,bicubic
tf_efficientnet_b3,68.52,31.48,88.7,11.3,12.23,300,0.904,bicubic
dpn98,68.58,31.42,87.66,12.34,61.57,224,0.875,bicubic
gluon_seresnext50_32x4d,68.67,31.33,88.32,11.68,27.56,224,0.875,bicubic
dpn107,68.71,31.29,88.13,11.87,86.92,224,0.875,bicubic
gluon_resnet101_v1s,68.72,31.28,87.9,12.1,44.67,224,0.875,bicubic
resnext50d_32x4d,68.75,31.25,88.31,11.69,25.05,224,0.875,bicubic
dpn131,68.76,31.24,87.48,12.52,79.25,224,0.875,bicubic
gluon_resnet152_v1b,68.81,31.19,87.71,12.29,60.19,224,0.875,bicubic
gluon_resnext101_32x4d,68.96,31.04,88.34,11.66,44.18,224,0.875,bicubic
gluon_resnet101_v1d,68.99,31.01,88.08,11.92,44.57,224,0.875,bicubic
gluon_resnet152_v1c,69.13,30.87,87.89,12.11,60.21,224,0.875,bicubic
seresnext101_32x4d,69.34,30.66,88.05,11.95,48.96,224,0.875,bilinear
inception_v4,69.35,30.65,88.78,11.22,42.68,299,0.875,bicubic
ens_adv_inception_resnet_v2,69.52,30.48,88.5,11.5,55.84,299,0.8975,bicubic
gluon_resnext101_64x4d,69.69,30.31,88.26,11.74,83.46,224,0.875,bicubic
gluon_resnet152_v1d,69.95,30.05,88.47,11.53,60.21,224,0.875,bicubic
gluon_seresnext101_32x4d,70.01,29.99,88.91,11.09,48.96,224,0.875,bicubic
inception_resnet_v2,70.12,29.88,88.68,11.32,55.84,299,0.8975,bicubic
gluon_resnet152_v1s,70.32,29.68,88.87,11.13,60.32,224,0.875,bicubic
gluon_seresnext101_64x4d,70.44,29.56,89.35,10.65,88.23,224,0.875,bicubic
senet154,70.48,29.52,88.99,11.01,115.09,224,0.875,bilinear
gluon_senet154,70.6,29.4,88.92,11.08,115.09,224,0.875,bicubic
tf_efficientnet_b4,71.34,28.66,90.11,9.89,19.34,380,0.922,bicubic
nasnetalarge,72.31,27.69,90.51,9.49,88.75,331,0.875,bicubic
pnasnet5large,72.37,27.63,90.26,9.74,86.06,331,0.875,bicubic
tf_efficientnet_b5,72.56,27.44,91.1,8.9,30.39,456,0.934,bicubic
ig_resnext101_32x8d,73.66,26.34,92.15,7.85,88.79,224,0.875,bilinear
ig_resnext101_32x16d,75.71,24.29,92.9,7.1,194.03,224,0.875,bilinear
ig_resnext101_32x32d,76.84,23.16,93.19,6.81,468.53,224,0.875,bilinear
ig_resnext101_32x48d,76.87,23.13,93.32,6.68,828.41,224,0.875,bilinear
1 model top1 top1_err top5 top5_err param_count img_size cropt_pct interpolation
2 resnet18 57.18 42.82 80.19 19.81 11.69 224 0.875 bilinear
3 gluon_resnet18_v1b 58.32 41.68 80.96 19.04 11.69 224 0.875 bicubic
4 seresnet18 59.81 40.19 81.68 18.32 11.78 224 0.875 bicubic
5 tv_resnet34 61.2 38.8 82.72 17.28 21.8 224 0.875 bilinear
6 spnasnet_100 61.21 38.79 82.77 17.23 4.42 224 0.875 bilinear
7 mnasnet_100 61.91 38.09 83.71 16.29 4.38 224 0.875 bicubic
8 fbnetc_100 62.43 37.57 83.39 16.61 5.57 224 0.875 bilinear
9 gluon_resnet34_v1b 62.56 37.44 84 16 21.8 224 0.875 bicubic
10 resnet34 62.82 37.18 84.12 15.88 21.8 224 0.875 bilinear
11 seresnet34 62.89 37.11 84.22 15.78 21.96 224 0.875 bilinear
12 densenet121 62.94 37.06 84.26 15.74 7.98 224 0.875 bicubic
13 semnasnet_100 63.12 36.88 84.53 15.47 3.89 224 0.875 bicubic
14 mobilenetv3_100 63.23 36.77 84.52 15.48 5.48 224 0.875 bicubic
15 tv_resnet50 63.33 36.67 84.65 15.35 25.56 224 0.875 bilinear
16 resnet26 63.45 36.55 84.27 15.73 16 224 0.875 bicubic
17 tf_efficientnet_b0 63.53 36.47 84.88 15.12 5.29 224 0.875 bicubic
18 dpn68 64.22 35.78 85.18 14.82 12.61 224 0.875 bicubic
19 efficientnet_b0 64.58 35.42 85.89 14.11 5.29 224 0.875 bicubic
20 resnet26d 64.63 35.37 85.12 14.88 16.01 224 0.875 bicubic
21 densenet169 64.78 35.22 85.25 14.75 14.15 224 0.875 bicubic
22 seresnext26_32x4d 65.04 34.96 85.65 14.35 16.79 224 0.875 bicubic
23 densenet201 65.28 34.72 85.67 14.33 20.01 224 0.875 bicubic
24 dpn68b 65.6 34.4 85.94 14.06 12.61 224 0.875 bicubic
25 resnet101 65.68 34.32 85.98 14.02 44.55 224 0.875 bilinear
26 densenet161 65.85 34.15 86.46 13.54 28.68 224 0.875 bicubic
27 gluon_resnet50_v1b 66.04 33.96 86.27 13.73 25.56 224 0.875 bicubic
28 inception_v3 66.12 33.88 86.34 13.66 27.16 299 0.875 bicubic
29 tv_resnext50_32x4d 66.18 33.82 86.04 13.96 25.03 224 0.875 bilinear
30 seresnet50 66.24 33.76 86.33 13.67 28.09 224 0.875 bilinear
31 tf_inception_v3 66.41 33.59 86.68 13.32 23.83 299 0.875 bicubic
32 tf_efficientnet_b1 66.52 33.48 86.68 13.32 7.79 240 0.882 bicubic
33 gluon_resnet50_v1c 66.54 33.46 86.16 13.84 25.58 224 0.875 bicubic
34 adv_inception_v3 66.6 33.4 86.56 13.44 23.83 299 0.875 bicubic
35 wide_resnet50_2 66.65 33.35 86.81 13.19 68.88 224 0.875 bilinear
36 wide_resnet101_2 66.68 33.32 87.04 12.96 126.89 224 0.875 bilinear
37 resnet50 66.81 33.19 87 13 25.56 224 0.875 bicubic
38 resnext50_32x4d 66.88 33.12 86.36 13.64 25.03 224 0.875 bicubic
39 resnet152 67.02 32.98 87.57 12.43 60.19 224 0.875 bilinear
40 gluon_resnet50_v1s 67.1 32.9 86.86 13.14 25.68 224 0.875 bicubic
41 seresnet101 67.15 32.85 87.05 12.95 49.33 224 0.875 bilinear
42 tf_efficientnet_b2 67.4 32.6 87.58 12.42 9.11 260 0.89 bicubic
43 gluon_resnet101_v1b 67.45 32.55 87.23 12.77 44.55 224 0.875 bicubic
44 efficientnet_b1 67.55 32.45 87.29 12.71 7.79 240 0.882 bicubic
45 seresnet152 67.55 32.45 87.39 12.61 66.82 224 0.875 bilinear
46 gluon_resnet101_v1c 67.56 32.44 87.16 12.84 44.57 224 0.875 bicubic
47 gluon_inception_v3 67.59 32.41 87.46 12.54 23.83 299 0.875 bicubic
48 xception 67.67 32.33 87.57 12.43 22.86 299 0.8975 bicubic
49 efficientnet_b2 67.8 32.2 88.2 11.8 9.11 260 0.89 bicubic
50 resnext101_32x8d 67.85 32.15 87.48 12.52 88.79 224 0.875 bilinear
51 seresnext50_32x4d 67.87 32.13 87.62 12.38 27.56 224 0.875 bilinear
52 gluon_resnet50_v1d 67.91 32.09 87.12 12.88 25.58 224 0.875 bicubic
53 dpn92 68.01 31.99 87.59 12.41 37.67 224 0.875 bicubic
54 gluon_resnext50_32x4d 68.28 31.72 87.32 12.68 25.03 224 0.875 bicubic
55 tf_efficientnet_b3 68.52 31.48 88.7 11.3 12.23 300 0.904 bicubic
56 dpn98 68.58 31.42 87.66 12.34 61.57 224 0.875 bicubic
57 gluon_seresnext50_32x4d 68.67 31.33 88.32 11.68 27.56 224 0.875 bicubic
58 dpn107 68.71 31.29 88.13 11.87 86.92 224 0.875 bicubic
59 gluon_resnet101_v1s 68.72 31.28 87.9 12.1 44.67 224 0.875 bicubic
60 resnext50d_32x4d 68.75 31.25 88.31 11.69 25.05 224 0.875 bicubic
61 dpn131 68.76 31.24 87.48 12.52 79.25 224 0.875 bicubic
62 gluon_resnet152_v1b 68.81 31.19 87.71 12.29 60.19 224 0.875 bicubic
63 gluon_resnext101_32x4d 68.96 31.04 88.34 11.66 44.18 224 0.875 bicubic
64 gluon_resnet101_v1d 68.99 31.01 88.08 11.92 44.57 224 0.875 bicubic
65 gluon_resnet152_v1c 69.13 30.87 87.89 12.11 60.21 224 0.875 bicubic
66 seresnext101_32x4d 69.34 30.66 88.05 11.95 48.96 224 0.875 bilinear
67 inception_v4 69.35 30.65 88.78 11.22 42.68 299 0.875 bicubic
68 ens_adv_inception_resnet_v2 69.52 30.48 88.5 11.5 55.84 299 0.8975 bicubic
69 gluon_resnext101_64x4d 69.69 30.31 88.26 11.74 83.46 224 0.875 bicubic
70 gluon_resnet152_v1d 69.95 30.05 88.47 11.53 60.21 224 0.875 bicubic
71 gluon_seresnext101_32x4d 70.01 29.99 88.91 11.09 48.96 224 0.875 bicubic
72 inception_resnet_v2 70.12 29.88 88.68 11.32 55.84 299 0.8975 bicubic
73 gluon_resnet152_v1s 70.32 29.68 88.87 11.13 60.32 224 0.875 bicubic
74 gluon_seresnext101_64x4d 70.44 29.56 89.35 10.65 88.23 224 0.875 bicubic
75 senet154 70.48 29.52 88.99 11.01 115.09 224 0.875 bilinear
76 gluon_senet154 70.6 29.4 88.92 11.08 115.09 224 0.875 bicubic
77 tf_efficientnet_b4 71.34 28.66 90.11 9.89 19.34 380 0.922 bicubic
78 nasnetalarge 72.31 27.69 90.51 9.49 88.75 331 0.875 bicubic
79 pnasnet5large 72.37 27.63 90.26 9.74 86.06 331 0.875 bicubic
80 tf_efficientnet_b5 72.56 27.44 91.1 8.9 30.39 456 0.934 bicubic
81 ig_resnext101_32x8d 73.66 26.34 92.15 7.85 88.79 224 0.875 bilinear
82 ig_resnext101_32x16d 75.71 24.29 92.9 7.1 194.03 224 0.875 bilinear
83 ig_resnext101_32x32d 76.84 23.16 93.19 6.81 468.53 224 0.875 bilinear
84 ig_resnext101_32x48d 76.87 23.13 93.32 6.68 828.41 224 0.875 bilinear

@ -1,11 +1,13 @@
"""Pytorch ResNet implementation w/ tweaks
This file is a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
"""PyTorch ResNet
This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
additional dropout and dynamic global avg/max pool.
ResNext additions added by Ross Wightman
ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants added by Ross Wightman
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -33,6 +35,12 @@ default_cfgs = {
'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
'resnet34': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'),
'resnet26': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth',
interpolation='bicubic'),
'resnet26d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth',
interpolation='bicubic'),
'resnet50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth',
interpolation='bicubic'),
@ -45,6 +53,9 @@ default_cfgs = {
'resnext50_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d-068914d1.pth',
interpolation='bicubic'),
'resnext50d_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth',
interpolation='bicubic'),
'resnext101_32x4d': _cfg(url=''),
'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
'resnext101_64x4d': _cfg(url=''),
@ -56,30 +67,57 @@ default_cfgs = {
}
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
def _get_padding(kernel_size, stride, dilation=1):
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
class SEModule(nn.Module):
def __init__(self, channels, reduction_channels):
super(SEModule, self).__init__()
#self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
def forward(self, x):
#x_se = self.avg_pool(x)
x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
x_se = self.fc1(x_se)
x_se = self.relu(x_se)
x_se = self.fc2(x_se)
return x * x_se.sigmoid()
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, drop_rate=0.0):
cardinality=1, base_width=64, use_se=False,
reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d):
super(BasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
assert base_width == 64, 'BasicBlock doest not support changing base width'
first_planes = planes // reduce_first
outplanes = planes * self.expansion
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.conv1 = nn.Conv2d(
inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
dilation=dilation, bias=False)
self.bn1 = norm_layer(first_planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=previous_dilation,
dilation=previous_dilation, bias=False)
self.bn2 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.downsample = downsample
self.stride = stride
self.drop_rate = drop_rate
self.dilation = dilation
def forward(self, x):
residual = x
@ -87,13 +125,12 @@ class BasicBlock(nn.Module):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
if self.drop_rate > 0.:
out = F.dropout(out, p=self.drop_rate, training=self.training)
out = self.conv2(out)
out = self.bn2(out)
if self.se is not None:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
@ -107,22 +144,27 @@ class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, drop_rate=0.0):
cardinality=1, base_width=64, use_se=False,
reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d):
super(Bottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality)
self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(width)
self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
padding=1, groups=cardinality, bias=False)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * 4)
first_planes = width // reduce_first
outplanes = planes * self.expansion
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
self.bn1 = norm_layer(first_planes)
self.conv2 = nn.Conv2d(
first_planes, width, kernel_size=3, stride=stride,
padding=dilation, dilation=dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(width)
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.drop_rate = drop_rate
self.dilation = dilation
def forward(self, x):
residual = x
@ -131,9 +173,6 @@ class Bottleneck(nn.Module):
out = self.bn1(out)
out = self.relu(out)
if self.drop_rate > 0.:
out = F.dropout(out, p=self.drop_rate, training=self.training)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
@ -141,6 +180,9 @@ class Bottleneck(nn.Module):
out = self.conv3(out)
out = self.bn3(out)
if self.se is not None:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
@ -151,26 +193,105 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, in_chans=3,
cardinality=1, base_width=64,
drop_rate=0.0, block_drop_rate=0.0,
global_pool='avg'):
"""ResNet / ResNeXt / SE-ResNeXt / SE-Net
This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that
* have > 1 stride in the 3x3 conv layer of bottleneck
* have conv-bn-act ordering
This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the
'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.
ResNet variants:
* normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
* c - 3 layer deep 3x3 stem, stem_width = 32
* d - 3 layer deep 3x3 stem, stem_width = 32, average pool in downsample
* e - 3 layer deep 3x3 stem, stem_width = 64, average pool in downsample
* s - 3 layer deep 3x3 stem, stem_width = 64
ResNeXt
* normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
* same c,d, e, s variants as ResNet can be enabled
SE-ResNeXt
* normal - 7x7 stem, stem_width = 64
* same c, d, e, s variants as ResNet can be enabled
SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64,
reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block
Parameters
----------
block : Block
Class for the residual block. Options are BasicBlockGl, BottleneckGl.
layers : list of int
Numbers of layers in each block
num_classes : int, default 1000
Number of classification classes.
in_chans : int, default 3
Number of input (color) channels.
use_se : bool, default False
Enable Squeeze-Excitation module in blocks
cardinality : int, default 1
Number of convolution groups for 3x3 conv in Bottleneck.
base_width : int, default 64
Factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
deep_stem : bool, default False
Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
stem_width : int, default 64
Number of channels in stem convolutions
block_reduce_first: int, default 1
Reduction factor for first convolution output width of residual blocks,
1 for all archs except senets, where 2
down_kernel_size: int, default 1
Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets
avg_down : bool, default False
Whether to use average pooling for projection skip connection between stages/downsample.
dilated : bool, default False
Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
typically used in Semantic Segmentation.
drop_rate : float, default 0.
Dropout probability before classifier, for training
global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
"""
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
cardinality=1, base_width=64, stem_width=64, deep_stem=False,
block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False,
norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg'):
self.num_classes = num_classes
self.inplanes = 64
self.inplanes = stem_width * 2 if deep_stem else 64
self.cardinality = cardinality
self.base_width = base_width
self.drop_rate = drop_rate
self.expansion = block.expansion
self.dilated = dilated
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
if deep_stem:
self.conv1 = nn.Sequential(*[
nn.Conv2d(in_chans, stem_width, 3, stride=2, padding=1, bias=False),
norm_layer(stem_width),
nn.ReLU(inplace=True),
nn.Conv2d(stem_width, stem_width, 3, stride=1, padding=1, bias=False),
norm_layer(stem_width),
nn.ReLU(inplace=True),
nn.Conv2d(stem_width, self.inplanes, 3, stride=1, padding=1, bias=False)])
else:
self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0], drop_rate=block_drop_rate)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, drop_rate=block_drop_rate)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, drop_rate=block_drop_rate)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, drop_rate=block_drop_rate)
stride_3_4 = 1 if self.dilated else 2
dilation_3 = 2 if self.dilated else 1
dilation_4 = 4 if self.dilated else 1
largs = dict(use_se=use_se, reduce_first=block_reduce_first, norm_layer=norm_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size)
self.layer1 = self._make_layer(block, 64, layers[0], stride=1, **largs)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, **largs)
self.layer3 = self._make_layer(block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, **largs)
self.layer4 = self._make_layer(block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, **largs)
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.num_features = 512 * block.expansion
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
@ -182,18 +303,35 @@ class ResNet(nn.Module):
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
def _make_layer(self, block, planes, blocks, stride=1, drop_rate=0.):
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d):
downsample = None
down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = [block(self.inplanes, planes, stride, downsample, self.cardinality, self.base_width, drop_rate)]
downsample_padding = _get_padding(down_kernel_size, stride)
downsample_layers = []
conv_stride = stride
if avg_down:
avg_stride = stride if dilation == 1 else 1
conv_stride = 1
downsample_layers = [nn.AvgPool2d(avg_stride, avg_stride, ceil_mode=True, count_include_pad=False)]
downsample_layers += [
nn.Conv2d(self.inplanes, planes * block.expansion, down_kernel_size,
stride=conv_stride, padding=downsample_padding, bias=False),
norm_layer(planes * block.expansion)]
downsample = nn.Sequential(*downsample_layers)
first_dilation = 1 if dilation in (1, 2) else 2
layers = [block(
self.inplanes, planes, stride, downsample,
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
use_se=use_se, dilation=first_dilation, previous_dilation=dilation, norm_layer=norm_layer)]
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, cardinality=self.cardinality, base_width=self.base_width))
layers.append(block(
self.inplanes, planes,
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
use_se=use_se, dilation=dilation, previous_dilation=dilation, norm_layer=norm_layer))
return nn.Sequential(*layers)
@ -257,6 +395,33 @@ def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model
@register_model
def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-26 model.
"""
default_cfg = default_cfgs['resnet26']
model = ResNet(Bottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-26 v1d model.
This is technically a 28 layer ResNet, sticking with 'd' modifier from Gluon for now.
"""
default_cfg = default_cfgs['resnet26d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], stem_width=32, deep_stem=True, avg_down=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-50 model.
@ -362,6 +527,21 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model
@register_model
def resnext50d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample
"""
default_cfg = default_cfgs['resnext50d_32x4d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
stem_width=32, deep_stem=True, avg_down=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNeXt-101 32x4d model.

Loading…
Cancel
Save