diff --git a/README.md b/README.md index 3160bba3..ff534700 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,9 @@ The work of many others is present here. I've tried to make sure all source mate I've included a few of my favourite models, but this is not an exhaustive collection. You can't do better than Cadene's collection in that regard. Most models do have pretrained weights from their respective sources or original authors. -* ResNet/ResNeXt (from [torchvision](https://github.com/pytorch/vision/tree/master/torchvision/models) with ResNeXt mods by myself) +* ResNet/ResNeXt (from [torchvision](https://github.com/pytorch/vision/tree/master/torchvision/models) with mods by myself) * ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152, ResNeXt50 (32x4d), ResNeXt101 (32x4d and 64x4d) + * 'Bag of Tricks' / Gluon C, D, E, S variations (https://arxiv.org/abs/1812.01187) * Instagram trained / ImageNet tuned ResNeXt101-32x8d to 32x48d from [facebookresearch](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/) * DenseNet (from [torchvision](https://github.com/pytorch/vision/tree/master/torchvision/models)) * DenseNet-121, DenseNet-169, DenseNet-201, DenseNet-161 @@ -70,12 +71,15 @@ I've leveraged the training scripts in this repository to train a few of the mod #### @ 224x224 |Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | |---|---|---|---|---| +| resnext50d_32x4d | 79.674 (20.326) | 94.868 (5.132) | 25.1M | bicubic | | resnext50_32x4d | 78.512 (21.488) | 94.042 (5.958) | 25M | bicubic | | resnet50 | 78.470 (21.530) | 94.266 (5.734) | 25.6M | bicubic | | seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic | | efficientnet_b0 | 76.912 (23.088) | 93.210 (6.790) | 5.29M | bicubic | +| resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic | | mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | | mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.89M | bicubic | +| resnet26 | 75.292 (24.708) | 92.57 (7.43) | 16M | bicubic | | fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6M | bilinear | | resnet34 | 75.110 (24.890) | 92.284 (7.716) | 22M | bilinear | | seresnet34 | 74.808 (25.192) | 92.124 (7.876) | 22M | bilinear | @@ -120,8 +124,6 @@ I've leveraged the training scripts in this repository to train a few of the mod | tf_efficientnet_b0 *tfp | 76.828 (23.172) | 93.226 (6.774) | 5.29 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) | | tf_efficientnet_b0 | 76.528 (23.472) | 93.010 (6.990) | 5.29 | bicubic | [Google](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) | | gluon_resnet34_v1b | 74.580 (25.420) | 91.988 (8.012) | 21.80 | bicubic | | -| tflite_semnasnet_100 | 73.086 (26.914) | 91.336 (8.664) | 3.87 | bicubic | [Google TFLite](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet) | -| tflite_mnasnet_100 | 72.398 (27.602) | 90.930 (9.070) | 4.36 | bicubic | [Google TFLite](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet) | | gluon_resnet18_v1b | 70.830 (29.170) | 89.756 (10.244) | 11.69 | bicubic | | #### @ 240x240 diff --git a/results/results-all.csv b/results/results-all.csv index a3cc0347..619444c2 100644 --- a/results/results-all.csv +++ b/results/results-all.csv @@ -2,8 +2,6 @@ model,top1,top1_err,top5,top5_err,param_count,img_size,cropt_pct,interpolation resnet18,69.758,30.242,89.078,10.922,11.69,224,0.875,bilinear gluon_resnet18_v1b,70.83,29.17,89.756,10.244,11.69,224,0.875,bicubic 
seresnet18,71.758,28.242,90.334,9.666,11.78,224,0.875,bicubic -tflite_mnasnet_100,72.4,27.6,90.936,9.064,4.36,224,0.875,bicubic -tflite_semnasnet_100,73.078,26.922,91.334,8.666,3.87,224,0.875,bicubic tv_resnet34,73.314,26.686,91.42,8.58,21.8,224,0.875,bilinear spnasnet_100,74.08,25.92,91.832,8.168,4.42,224,0.875,bilinear gluon_resnet34_v1b,74.58,25.42,91.988,8.012,21.8,224,0.875,bicubic @@ -12,12 +10,14 @@ densenet121,74.752,25.248,92.152,7.848,7.98,224,0.875,bicubic seresnet34,74.808,25.192,92.126,7.874,21.96,224,0.875,bilinear resnet34,75.112,24.888,92.288,7.712,21.8,224,0.875,bilinear fbnetc_100,75.12,24.88,92.386,7.614,5.57,224,0.875,bilinear +resnet26,75.292,24.708,92.57,7.43,16,224,0.875,bicubic semnasnet_100,75.456,24.544,92.592,7.408,3.89,224,0.875,bicubic mobilenetv3_100,75.628,24.372,92.708,7.292,5.48,224,0.875,bicubic densenet169,75.912,24.088,93.024,6.976,14.15,224,0.875,bicubic tv_resnet50,76.13,23.87,92.862,7.138,25.56,224,0.875,bilinear dpn68,76.306,23.694,92.97,7.03,12.61,224,0.875,bicubic tf_efficientnet_b0,76.528,23.472,93.01,6.99,5.29,224,0.875,bicubic +resnet26d,76.68,23.32,93.166,6.834,16.01,224,0.875,bicubic efficientnet_b0,76.914,23.086,93.206,6.794,5.29,224,0.875,bicubic seresnext26_32x4d,77.1,22.9,93.31,6.69,16.79,224,0.875,bicubic densenet201,77.29,22.71,93.478,6.522,20.01,224,0.875,bicubic @@ -30,7 +30,7 @@ gluon_resnet50_v1b,77.578,22.422,93.718,6.282,25.56,224,0.875,bicubic tv_resnext50_32x4d,77.618,22.382,93.698,6.302,25.03,224,0.875,bilinear seresnet50,77.636,22.364,93.752,6.248,28.09,224,0.875,bilinear tf_inception_v3,77.856,22.144,93.644,6.356,23.83,299,0.875,bicubic -gluon_resnet50_v1c,78.01,21.99,93.988,6.012,25.58,224,0.875,bicubic +gluon_resnet50_v1c,78.012,21.988,93.988,6.012,25.58,224,0.875,bicubic resnet152,78.312,21.688,94.046,5.954,60.19,224,0.875,bilinear seresnet101,78.396,21.604,94.258,5.742,49.33,224,0.875,bilinear wide_resnet50_2,78.468,21.532,94.086,5.914,68.88,224,0.875,bilinear @@ -51,6 +51,7 @@ gluon_resnext50_32x4d,79.356,20.644,94.424,5.576,25.03,224,0.875,bicubic gluon_resnet101_v1c,79.544,20.456,94.586,5.414,44.57,224,0.875,bicubic tf_efficientnet_b2,79.606,20.394,94.712,5.288,9.11,260,0.89,bicubic dpn98,79.636,20.364,94.594,5.406,61.57,224,0.875,bicubic +resnext50d_32x4d,79.674,20.326,94.868,5.132,25.05,224,0.875,bicubic gluon_resnet152_v1b,79.692,20.308,94.738,5.262,60.19,224,0.875,bicubic efficientnet_b2,79.752,20.248,94.71,5.29,9.11,260,0.89,bicubic dpn131,79.828,20.172,94.704,5.296,79.25,224,0.875,bicubic diff --git a/results/results-inv2-matched-frequency.csv b/results/results-inv2-matched-frequency.csv new file mode 100644 index 00000000..14e85131 --- /dev/null +++ b/results/results-inv2-matched-frequency.csv @@ -0,0 +1,84 @@ +model,top1,top1_err,top5,top5_err,param_count,img_size,cropt_pct,interpolation +resnet18,57.18,42.82,80.19,19.81,11.69,224,0.875,bilinear +gluon_resnet18_v1b,58.32,41.68,80.96,19.04,11.69,224,0.875,bicubic +seresnet18,59.81,40.19,81.68,18.32,11.78,224,0.875,bicubic +tv_resnet34,61.2,38.8,82.72,17.28,21.8,224,0.875,bilinear +spnasnet_100,61.21,38.79,82.77,17.23,4.42,224,0.875,bilinear +mnasnet_100,61.91,38.09,83.71,16.29,4.38,224,0.875,bicubic +fbnetc_100,62.43,37.57,83.39,16.61,5.57,224,0.875,bilinear +gluon_resnet34_v1b,62.56,37.44,84,16,21.8,224,0.875,bicubic +resnet34,62.82,37.18,84.12,15.88,21.8,224,0.875,bilinear +seresnet34,62.89,37.11,84.22,15.78,21.96,224,0.875,bilinear +densenet121,62.94,37.06,84.26,15.74,7.98,224,0.875,bicubic +semnasnet_100,63.12,36.88,84.53,15.47,3.89,224,0.875,bicubic 
+mobilenetv3_100,63.23,36.77,84.52,15.48,5.48,224,0.875,bicubic +tv_resnet50,63.33,36.67,84.65,15.35,25.56,224,0.875,bilinear +resnet26,63.45,36.55,84.27,15.73,16,224,0.875,bicubic +tf_efficientnet_b0,63.53,36.47,84.88,15.12,5.29,224,0.875,bicubic +dpn68,64.22,35.78,85.18,14.82,12.61,224,0.875,bicubic +efficientnet_b0,64.58,35.42,85.89,14.11,5.29,224,0.875,bicubic +resnet26d,64.63,35.37,85.12,14.88,16.01,224,0.875,bicubic +densenet169,64.78,35.22,85.25,14.75,14.15,224,0.875,bicubic +seresnext26_32x4d,65.04,34.96,85.65,14.35,16.79,224,0.875,bicubic +densenet201,65.28,34.72,85.67,14.33,20.01,224,0.875,bicubic +dpn68b,65.6,34.4,85.94,14.06,12.61,224,0.875,bicubic +resnet101,65.68,34.32,85.98,14.02,44.55,224,0.875,bilinear +densenet161,65.85,34.15,86.46,13.54,28.68,224,0.875,bicubic +gluon_resnet50_v1b,66.04,33.96,86.27,13.73,25.56,224,0.875,bicubic +inception_v3,66.12,33.88,86.34,13.66,27.16,299,0.875,bicubic +tv_resnext50_32x4d,66.18,33.82,86.04,13.96,25.03,224,0.875,bilinear +seresnet50,66.24,33.76,86.33,13.67,28.09,224,0.875,bilinear +tf_inception_v3,66.41,33.59,86.68,13.32,23.83,299,0.875,bicubic +tf_efficientnet_b1,66.52,33.48,86.68,13.32,7.79,240,0.882,bicubic +gluon_resnet50_v1c,66.54,33.46,86.16,13.84,25.58,224,0.875,bicubic +adv_inception_v3,66.6,33.4,86.56,13.44,23.83,299,0.875,bicubic +wide_resnet50_2,66.65,33.35,86.81,13.19,68.88,224,0.875,bilinear +wide_resnet101_2,66.68,33.32,87.04,12.96,126.89,224,0.875,bilinear +resnet50,66.81,33.19,87,13,25.56,224,0.875,bicubic +resnext50_32x4d,66.88,33.12,86.36,13.64,25.03,224,0.875,bicubic +resnet152,67.02,32.98,87.57,12.43,60.19,224,0.875,bilinear +gluon_resnet50_v1s,67.1,32.9,86.86,13.14,25.68,224,0.875,bicubic +seresnet101,67.15,32.85,87.05,12.95,49.33,224,0.875,bilinear +tf_efficientnet_b2,67.4,32.6,87.58,12.42,9.11,260,0.89,bicubic +gluon_resnet101_v1b,67.45,32.55,87.23,12.77,44.55,224,0.875,bicubic +efficientnet_b1,67.55,32.45,87.29,12.71,7.79,240,0.882,bicubic +seresnet152,67.55,32.45,87.39,12.61,66.82,224,0.875,bilinear +gluon_resnet101_v1c,67.56,32.44,87.16,12.84,44.57,224,0.875,bicubic +gluon_inception_v3,67.59,32.41,87.46,12.54,23.83,299,0.875,bicubic +xception,67.67,32.33,87.57,12.43,22.86,299,0.8975,bicubic +efficientnet_b2,67.8,32.2,88.2,11.8,9.11,260,0.89,bicubic +resnext101_32x8d,67.85,32.15,87.48,12.52,88.79,224,0.875,bilinear +seresnext50_32x4d,67.87,32.13,87.62,12.38,27.56,224,0.875,bilinear +gluon_resnet50_v1d,67.91,32.09,87.12,12.88,25.58,224,0.875,bicubic +dpn92,68.01,31.99,87.59,12.41,37.67,224,0.875,bicubic +gluon_resnext50_32x4d,68.28,31.72,87.32,12.68,25.03,224,0.875,bicubic +tf_efficientnet_b3,68.52,31.48,88.7,11.3,12.23,300,0.904,bicubic +dpn98,68.58,31.42,87.66,12.34,61.57,224,0.875,bicubic +gluon_seresnext50_32x4d,68.67,31.33,88.32,11.68,27.56,224,0.875,bicubic +dpn107,68.71,31.29,88.13,11.87,86.92,224,0.875,bicubic +gluon_resnet101_v1s,68.72,31.28,87.9,12.1,44.67,224,0.875,bicubic +resnext50d_32x4d,68.75,31.25,88.31,11.69,25.05,224,0.875,bicubic +dpn131,68.76,31.24,87.48,12.52,79.25,224,0.875,bicubic +gluon_resnet152_v1b,68.81,31.19,87.71,12.29,60.19,224,0.875,bicubic +gluon_resnext101_32x4d,68.96,31.04,88.34,11.66,44.18,224,0.875,bicubic +gluon_resnet101_v1d,68.99,31.01,88.08,11.92,44.57,224,0.875,bicubic +gluon_resnet152_v1c,69.13,30.87,87.89,12.11,60.21,224,0.875,bicubic +seresnext101_32x4d,69.34,30.66,88.05,11.95,48.96,224,0.875,bilinear +inception_v4,69.35,30.65,88.78,11.22,42.68,299,0.875,bicubic +ens_adv_inception_resnet_v2,69.52,30.48,88.5,11.5,55.84,299,0.8975,bicubic 
+gluon_resnext101_64x4d,69.69,30.31,88.26,11.74,83.46,224,0.875,bicubic +gluon_resnet152_v1d,69.95,30.05,88.47,11.53,60.21,224,0.875,bicubic +gluon_seresnext101_32x4d,70.01,29.99,88.91,11.09,48.96,224,0.875,bicubic +inception_resnet_v2,70.12,29.88,88.68,11.32,55.84,299,0.8975,bicubic +gluon_resnet152_v1s,70.32,29.68,88.87,11.13,60.32,224,0.875,bicubic +gluon_seresnext101_64x4d,70.44,29.56,89.35,10.65,88.23,224,0.875,bicubic +senet154,70.48,29.52,88.99,11.01,115.09,224,0.875,bilinear +gluon_senet154,70.6,29.4,88.92,11.08,115.09,224,0.875,bicubic +tf_efficientnet_b4,71.34,28.66,90.11,9.89,19.34,380,0.922,bicubic +nasnetalarge,72.31,27.69,90.51,9.49,88.75,331,0.875,bicubic +pnasnet5large,72.37,27.63,90.26,9.74,86.06,331,0.875,bicubic +tf_efficientnet_b5,72.56,27.44,91.1,8.9,30.39,456,0.934,bicubic +ig_resnext101_32x8d,73.66,26.34,92.15,7.85,88.79,224,0.875,bilinear +ig_resnext101_32x16d,75.71,24.29,92.9,7.1,194.03,224,0.875,bilinear +ig_resnext101_32x32d,76.84,23.16,93.19,6.81,468.53,224,0.875,bilinear +ig_resnext101_32x48d,76.87,23.13,93.32,6.68,828.41,224,0.875,bilinear diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 32ff3acf..eff83066 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -1,11 +1,13 @@ -"""Pytorch ResNet implementation w/ tweaks -This file is a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with +"""PyTorch ResNet + +This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with additional dropout and dynamic global avg/max pool. -ResNext additions added by Ross Wightman +ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants added by Ross Wightman """ import math +import torch import torch.nn as nn import torch.nn.functional as F @@ -33,6 +35,12 @@ default_cfgs = { 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), 'resnet34': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'), + 'resnet26': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth', + interpolation='bicubic'), + 'resnet26d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth', + interpolation='bicubic'), 'resnet50': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth', interpolation='bicubic'), @@ -45,6 +53,9 @@ default_cfgs = { 'resnext50_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d-068914d1.pth', interpolation='bicubic'), + 'resnext50d_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth', + interpolation='bicubic'), 'resnext101_32x4d': _cfg(url=''), 'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'), 'resnext101_64x4d': _cfg(url=''), @@ -56,30 +67,57 @@ default_cfgs = { } -def conv3x3(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d( - in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) +def _get_padding(kernel_size, stride, dilation=1): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +class SEModule(nn.Module): + + def __init__(self, channels, reduction_channels): + super(SEModule, self).__init__() + #self.avg_pool = 
nn.AdaptiveAvgPool2d(1) + self.fc1 = nn.Conv2d( + channels, reduction_channels, kernel_size=1, padding=0, bias=True) + self.relu = nn.ReLU(inplace=True) + self.fc2 = nn.Conv2d( + reduction_channels, channels, kernel_size=1, padding=0, bias=True) + + def forward(self, x): + #x_se = self.avg_pool(x) + x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) + x_se = self.fc1(x_se) + x_se = self.relu(x_se) + x_se = self.fc2(x_se) + return x * x_se.sigmoid() class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, drop_rate=0.0): + cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock does not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = nn.BatchNorm2d(planes) + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, + dilation=dilation, bias=False) + self.bn1 = norm_layer(first_planes) self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + first_planes, outplanes, kernel_size=3, padding=previous_dilation, + dilation=previous_dilation, bias=False) + self.bn2 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None self.downsample = downsample self.stride = stride - self.drop_rate = drop_rate + self.dilation = dilation def forward(self, x): residual = x @@ -87,13 +125,12 @@ class BasicBlock(nn.Module): out = self.conv1(x) out = self.bn1(out) out = self.relu(out) - - if self.drop_rate > 0.: - out = F.dropout(out, p=self.drop_rate, training=self.training) - out = self.conv2(out) out = self.bn2(out) + if self.se is not None: + out = self.se(out) + if self.downsample is not None: residual = self.downsample(x) @@ -107,22 +144,27 @@ class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, drop_rate=0.0): + cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, previous_dilation=1, norm_layer=nn.BatchNorm2d): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) - - self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(width) - self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, - padding=1, groups=cardinality, bias=False) - self.bn2 = nn.BatchNorm2d(width) - self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * 4) + first_planes = width // reduce_first + outplanes = planes * self.expansion + + self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(first_planes) + self.conv2 = nn.Conv2d( + first_planes, width, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(width) + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride - self.drop_rate = drop_rate + 
self.dilation = dilation def forward(self, x): residual = x @@ -131,9 +173,6 @@ class Bottleneck(nn.Module): out = self.bn1(out) out = self.relu(out) - if self.drop_rate > 0.: - out = F.dropout(out, p=self.drop_rate, training=self.training) - out = self.conv2(out) out = self.bn2(out) out = self.relu(out) @@ -141,6 +180,9 @@ class Bottleneck(nn.Module): out = self.conv3(out) out = self.bn3(out) + if self.se is not None: + out = self.se(out) + if self.downsample is not None: residual = self.downsample(x) @@ -151,26 +193,105 @@ class Bottleneck(nn.Module): class ResNet(nn.Module): - - def __init__(self, block, layers, num_classes=1000, in_chans=3, - cardinality=1, base_width=64, - drop_rate=0.0, block_drop_rate=0.0, - global_pool='avg'): + """ResNet / ResNeXt / SE-ResNeXt / SENet + + This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that + * apply stride > 1 in the 3x3 conv layer of the bottleneck + * have conv-bn-act ordering + + This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s + variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the + 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to the torchvision default. + + ResNet variants: + * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b + * c - 3 layer deep 3x3 stem, stem_width = 32 + * d - 3 layer deep 3x3 stem, stem_width = 32, average pool in downsample + * e - 3 layer deep 3x3 stem, stem_width = 64, average pool in downsample + * s - 3 layer deep 3x3 stem, stem_width = 64 + + ResNeXt + * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths + * same c, d, e, s variants as ResNet can be enabled + + SE-ResNeXt + * normal - 7x7 stem, stem_width = 64 + * same c, d, e, s variants as ResNet can be enabled + + SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, + reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block + + Parameters + ---------- + block : Block + Class for the residual block. Options are BasicBlock, Bottleneck. + layers : list of int + Number of residual blocks in each of the four stages + num_classes : int, default 1000 + Number of classification classes. + in_chans : int, default 3 + Number of input (color) channels. + use_se : bool, default False + Enable the Squeeze-and-Excitation module in each block + cardinality : int, default 1 + Number of convolution groups for 3x3 conv in Bottleneck. + base_width : int, default 64 + Factor determining bottleneck channels. `planes * base_width / 64 * cardinality` + deep_stem : bool, default False + Whether to replace the 7x7 conv1 with three 3x3 convolution layers. + stem_width : int, default 64 + Number of channels in stem convolutions + block_reduce_first: int, default 1 + Reduction factor for first convolution output width of residual blocks, + 1 for all archs except SENets, where it is 2 + down_kernel_size: int, default 1 + Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for SENets + avg_down : bool, default False + Whether to use average pooling for the projection skip connection between stages. + dilated : bool, default False + Whether to apply a dilation strategy to layers 3 and 4, yielding a stride-8 model + typically used for semantic segmentation. + drop_rate : float, default 0. + Dropout probability before the classifier, for training + global_pool : str, default 'avg' + Global pooling type. 
One of 'avg', 'max', 'avgmax', 'catavgmax' + """ + def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, + cardinality=1, base_width=64, stem_width=64, deep_stem=False, + block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False, + norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg'): self.num_classes = num_classes - self.inplanes = 64 + self.inplanes = stem_width * 2 if deep_stem else 64 self.cardinality = cardinality self.base_width = base_width self.drop_rate = drop_rate self.expansion = block.expansion + self.dilated = dilated super(ResNet, self).__init__() - self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False) - self.bn1 = nn.BatchNorm2d(64) + + if deep_stem: + self.conv1 = nn.Sequential(*[ + nn.Conv2d(in_chans, stem_width, 3, stride=2, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(inplace=True), + nn.Conv2d(stem_width, stem_width, 3, stride=1, padding=1, bias=False), + norm_layer(stem_width), + nn.ReLU(inplace=True), + nn.Conv2d(stem_width, self.inplanes, 3, stride=1, padding=1, bias=False)]) + else: + self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0], drop_rate=block_drop_rate) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, drop_rate=block_drop_rate) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, drop_rate=block_drop_rate) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, drop_rate=block_drop_rate) + stride_3_4 = 1 if self.dilated else 2 + dilation_3 = 2 if self.dilated else 1 + dilation_4 = 4 if self.dilated else 1 + largs = dict(use_se=use_se, reduce_first=block_reduce_first, norm_layer=norm_layer, + avg_down=avg_down, down_kernel_size=down_kernel_size) + self.layer1 = self._make_layer(block, 64, layers[0], stride=1, **largs) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, **largs) + self.layer3 = self._make_layer(block, 256, layers[2], stride=stride_3_4, dilation=dilation_3, **largs) + self.layer4 = self._make_layer(block, 512, layers[3], stride=stride_3_4, dilation=dilation_4, **largs) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.num_features = 512 * block.expansion self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) @@ -182,18 +303,35 @@ class ResNet(nn.Module): nn.init.constant_(m.weight, 1.) nn.init.constant_(m.bias, 0.) 
- def _make_layer(self, block, planes, blocks, stride=1, drop_rate=0.): + def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1, + use_se=False, avg_down=False, down_kernel_size=1, norm_layer=nn.BatchNorm2d): downsample = None + down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block.expansion), - ) - - layers = [block(self.inplanes, planes, stride, downsample, self.cardinality, self.base_width, drop_rate)] + downsample_padding = _get_padding(down_kernel_size, stride) + downsample_layers = [] + conv_stride = stride + if avg_down: + avg_stride = stride if dilation == 1 else 1 + conv_stride = 1 + downsample_layers = [nn.AvgPool2d(avg_stride, avg_stride, ceil_mode=True, count_include_pad=False)] + downsample_layers += [ + nn.Conv2d(self.inplanes, planes * block.expansion, down_kernel_size, + stride=conv_stride, padding=downsample_padding, bias=False), + norm_layer(planes * block.expansion)] + downsample = nn.Sequential(*downsample_layers) + + first_dilation = 1 if dilation in (1, 2) else 2 + layers = [block( + self.inplanes, planes, stride, downsample, + cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, + use_se=use_se, dilation=first_dilation, previous_dilation=dilation, norm_layer=norm_layer)] self.inplanes = planes * block.expansion for i in range(1, blocks): - layers.append(block(self.inplanes, planes, cardinality=self.cardinality, base_width=self.base_width)) + layers.append(block( + self.inplanes, planes, + cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, + use_se=use_se, dilation=dilation, previous_dilation=dilation, norm_layer=norm_layer)) return nn.Sequential(*layers) @@ -257,6 +395,33 @@ def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-26 model. + """ + default_cfg = default_cfgs['resnet26'] + model = ResNet(Bottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-26 v1d model. + This is technically a 28 layer ResNet, sticking with 'd' modifier from Gluon for now. + """ + default_cfg = default_cfgs['resnet26d'] + model = ResNet( + Bottleneck, [2, 2, 2, 2], stem_width=32, deep_stem=True, avg_down=True, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-50 model. @@ -362,6 +527,21 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def resnext50d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNeXt50d-32x4d model. 
ResNeXt50 w/ deep stem & avg pool downsample + """ + default_cfg = default_cfgs['resnext50d_32x4d'] + model = ResNet( + Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, + stem_width=32, deep_stem=True, avg_down=True, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt-101 32x4d model.
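Pulled out of the diff for readability, the new `SEModule` is a compact squeeze-and-excitation channel gate. The squeeze step uses `view`/`mean` rather than the commented-out `nn.AdaptiveAvgPool2d(1)`; both compute the same per-channel spatial average. A minimal self-contained version of the same computation:

```python
import torch
import torch.nn as nn

class SEModule(nn.Module):
    """Squeeze-and-Excitation channel gate, as added to timm/models/resnet.py above."""
    def __init__(self, channels, reduction_channels):
        super(SEModule, self).__init__()
        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)

    def forward(self, x):
        # squeeze: per-channel global average over HxW, kept as an NxCx1x1 tensor
        x_se = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
        # excite: bottleneck MLP implemented with 1x1 convs
        x_se = self.fc2(self.relu(self.fc1(x_se)))
        # scale: sigmoid-gate each input channel
        return x * x_se.sigmoid()

x = torch.randn(2, 256, 56, 56)
se = SEModule(256, 256 // 16)
assert se(x).shape == x.shape
```

The blocks instantiate it as `SEModule(outplanes, planes // 4)`; for `Bottleneck` that is a 16x reduction, since `outplanes = planes * 4`.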
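The 'd' and 'e' variants change the projection shortcut assembled in `_make_layer`: the stride moves out of the 1x1 convolution (which at stride 2 would sample only one in four spatial positions) into an average pool in front of it. A simplified sketch of that path under the same logic, using a hypothetical `make_downsample` helper and omitting the dilation handling:

```python
import torch
import torch.nn as nn

def _get_padding(kernel_size, stride, dilation=1):
    # 'same'-style padding, as defined in the diff above
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2

def make_downsample(inplanes, outplanes, stride=2, avg_down=False, down_kernel_size=1):
    """Projection shortcut per ResNet._make_layer (sketch; dilation handling omitted)."""
    layers = []
    conv_stride = stride
    if avg_down:
        # 'd'/'e' tweak: carry the stride in an avg pool so the 1x1 conv runs at
        # stride 1 and sees every spatial position instead of skipping 3 of every 4
        conv_stride = 1
        layers.append(nn.AvgPool2d(stride, stride, ceil_mode=True, count_include_pad=False))
    layers += [
        nn.Conv2d(inplanes, outplanes, down_kernel_size, stride=conv_stride,
                  padding=_get_padding(down_kernel_size, stride), bias=False),
        nn.BatchNorm2d(outplanes)]
    return nn.Sequential(*layers)

x = torch.randn(1, 256, 56, 56)
print(make_downsample(256, 512, avg_down=True)(x).shape)  # torch.Size([1, 512, 28, 28])
```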
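End to end, the new registrations behave like any other timm model. A usage sketch, assuming the `timm.create_model` factory that the `@register_model` decorator feeds (the pretrained weight URLs are in `default_cfgs` above):

```python
import torch
import timm

# new in this diff: 'resnet26', 'resnet26d', 'resnext50d_32x4d'
model = timm.create_model('resnet26d', pretrained=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])

# the Gluon-style tweaks also compose directly on the ResNet class, e.g. a
# hypothetical (unregistered) 'd'-style ResNet-50:
from timm.models.resnet import ResNet, Bottleneck
resnet50d_like = ResNet(Bottleneck, [3, 4, 6, 3], stem_width=32, deep_stem=True, avg_down=True)
```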