diff --git a/README.md b/README.md index 70e2a701..3cd5c223 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,9 @@ ## What's New +### May 12, 2020 +* Add ResNeSt models (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955)) + ### May 3, 2020 * Pruned EfficientNet B1, B2, and B3 (https://arxiv.org/abs/2002.08258) contributed by [Yonathan Aflalo](https://github.com/yoniaflalo) @@ -70,41 +73,6 @@ * Add RandAugment trained EfficientNet-B0 weight with 77.7 top-1. Trained by [Michael Klachko](https://github.com/michaelklachko) with this code and recent hparams (see Training section) * Add `avg_checkpoints.py` script for post training weight averaging and update all scripts with header docstrings and shebangs. -### Dec 30, 2019 -* Merge [Dushyant Mehta's](https://github.com/mehtadushy) PR for SelecSLS (Selective Short and Long Range Skip Connections) networks. Good GPU memory consumption and throughput. Original: https://github.com/mehtadushy/SelecSLS-Pytorch - -### Dec 28, 2019 -* Add new model weights and training hparams (see Training Hparams section) - * `efficientnet_b3` - 81.5 top-1, 95.7 top-5 at default res/crop, 81.9, 95.8 at 320x320 1.0 crop-pct - * trained with RandAugment, ended up with an interesting but less than perfect result (see training section) - * `seresnext26d_32x4d`- 77.6 top-1, 93.6 top-5 - * deep stem (32, 32, 64), avgpool downsample - * stem/dowsample from bag-of-tricks paper - * `seresnext26t_32x4d`- 78.0 top-1, 93.7 top-5 - * deep tiered stem (24, 48, 64), avgpool downsample (a modified 'D' variant) - * stem sizing mods from Jeremy Howard and fastai devs discussing ResNet architecture experiments - -### Dec 23, 2019 -* Add RandAugment trained MixNet-XL weights with 80.48 top-1. -* `--dist-bn` argument added to train.py, will distribute BN stats between nodes after each train epoch, before eval - -### Dec 4, 2019 -* Added weights from the first training from scratch of an EfficientNet (B2) with my new RandAugment implementation. Much better than my previous B2 and very close to the official AdvProp ones (80.4 top-1, 95.08 top-5). - -### Nov 29, 2019 -* Brought EfficientNet and MobileNetV3 up to date with my https://github.com/rwightman/gen-efficientnet-pytorch code. Torchscript and ONNX export compat excluded. - * AdvProp weights added - * Official TF MobileNetv3 weights added -* EfficientNet and MobileNetV3 hook based 'feature extraction' classes added. Will serve as basis for using models as backbones in obj detection/segmentation tasks. Lots more to be done here... -* HRNet classification models and weights added from https://github.com/HRNet/HRNet-Image-Classification -* Consistency in global pooling, `reset_classifer`, and `forward_features` across models - * `forward_features` always returns unpooled feature maps now -* Reasonable chance I broke something... let me know - -### Nov 22, 2019 -* Add ImageNet training RandAugment implementation alongside AutoAugment. PyTorch Transform compatible format, using PIL. Currently training two EfficientNet models from scratch with promising results... will update. -* `drop-connect` cmd line arg finally added to `train.py`, no need to hack model fns. Works for efficientnet/mobilenetv3 based models, ignored otherwise. - ## Introduction For each competition, personal, or freelance project involving images + Convolution Neural Networks, I build on top of an evolving collection of code and models. This repo contains a (somewhat) cleaned up and paired down iteration of that code. Hopefully it'll be of use to others. @@ -130,6 +98,7 @@ Included models: * Instagram trained / ImageNet tuned ResNeXt101-32x8d to 32x48d from from [facebookresearch](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/) * Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169) * Selective Kernel (SK) Nets (https://arxiv.org/abs/1903.06586) + * ResNeSt (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955) * DLA * Original (https://github.com/ucbdrive/dla, https://arxiv.org/abs/1707.06484) * Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index b073eb3a..d421ad45 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -18,6 +18,7 @@ from .dla import * from .hrnet import * from .sknet import * from .tresnet import * +from .resnest import * from .registry import * from .factory import create_model diff --git a/timm/models/layers/split_attn.py b/timm/models/layers/split_attn.py new file mode 100644 index 00000000..383c4583 --- /dev/null +++ b/timm/models/layers/split_attn.py @@ -0,0 +1,81 @@ +""" Split Attention Conv2d (for ResNeSt Models) + +Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt + +Modified for torchscript compat, performance, and consistency with timm by Ross Wightman +""" +import torch +import torch.nn.functional as F +from torch import nn + + +class RadixSoftmax(nn.Module): + def __init__(self, radix, cardinality): + super(RadixSoftmax, self).__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttnConv2d(nn.Module): + """Split-Attention Conv2d + """ + def __init__(self, in_channels, channels, kernel_size, stride=1, padding=0, + dilation=1, groups=1, bias=False, radix=2, reduction_factor=4, + act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): + super(SplitAttnConv2d, self).__init__() + self.radix = radix + self.cardinality = groups + self.channels = channels + mid_chs = channels * radix + attn_chs = max(in_channels * radix // reduction_factor, 32) + self.conv = nn.Conv2d( + in_channels, mid_chs, kernel_size, stride, padding, dilation, + groups=groups * radix, bias=bias, **kwargs) + self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None + self.act0 = act_layer(inplace=True) + self.fc1 = nn.Conv2d(channels, attn_chs, 1, groups=self.cardinality) + self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None + self.act1 = act_layer(inplace=True) + self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=self.cardinality) + self.drop_block = drop_block + self.rsoftmax = RadixSoftmax(radix, groups) + + def forward(self, x): + x = self.conv(x) + if self.bn0 is not None: + x = self.bn0(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act0(x) + + B, RC, H, W = x.shape + if self.radix > 1: + x = x.reshape((B, self.radix, RC // self.radix, H, W)) + x_gap = torch.sum(x, dim=1) + else: + x_gap = x + x_gap = F.adaptive_avg_pool2d(x_gap, 1) + x_gap = self.fc1(x_gap) + if self.bn1 is not None: + x_gap = self.bn1(x_gap) + x_gap = self.act1(x_gap) + x_attn = self.fc2(x_gap) + + x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) + if self.radix > 1: + out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) + else: + out = x * x_attn + return out.contiguous() diff --git a/timm/models/resnest.py b/timm/models/resnest.py new file mode 100644 index 00000000..849543ba --- /dev/null +++ b/timm/models/resnest.py @@ -0,0 +1,267 @@ +""" ResNeSt Models + +Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt by Hang Zhang + +Modified for torchscript compat, and consistency with timm by Ross Wightman +""" +import math +import torch +import torch.nn.functional as F +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import DropBlock2d +from .helpers import load_pretrained +from .layers import SelectiveKernelConv, ConvBnAct, create_attn +from .layers.split_attn import SplitAttnConv2d +from .registry import register_model +from .resnet import ResNet + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + +default_cfgs = { + 'resnest14d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'), + 'resnest26d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'), + 'resnest50d': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'), + 'resnest101e': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth', input_size=(3, 256, 256)), + 'resnest200e': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)), + 'resnest269e': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)), + 'resnest50d_4s2x40d': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_4s2x40d-41d14ed0.pth', + interpolation='bicubic'), + 'resnest50d_1s4x24d': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_1s4x24d-d4a4f76f.pth', + interpolation='bicubic') +} + + +class ResNestBottleneck(nn.Module): + """ResNet Bottleneck + """ + # pylint: disable=unused-argument + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(ResNestBottleneck, self).__init__() + assert reduce_first == 1 # not supported + assert attn_layer is None # not supported + assert aa_layer is None # TODO not yet supported + assert drop_path is None # TODO not yet supported + + group_width = int(planes * (base_width / 64.)) * cardinality + first_dilation = first_dilation or dilation + if avd and (stride > 1 or is_first): + avd_stride = stride + stride = 1 + else: + avd_stride = 0 + self.radix = radix + + self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) + self.bn1 = norm_layer(group_width) + self.drop_block1 = drop_block if drop_block is not None else None + self.act1 = act_layer(inplace=True) + self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None + + if self.radix >= 1: + self.conv2 = SplitAttnConv2d( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block) + self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness + self.drop_block2 = None + self.act2 = None + else: + self.conv2 = nn.Conv2d( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(group_width) + self.drop_block2 = drop_block if drop_block is not None else None + self.act2 = act_layer(inplace=True) + self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None + + self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes*4) + self.drop_block3 = drop_block if drop_block is not None else None + self.act3 = act_layer(inplace=True) + self.downsample = downsample + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + if self.drop_block1 is not None: + out = self.drop_block1(out) + out = self.act1(out) + + if self.avd_first is not None: + out = self.avd_first(out) + + out = self.conv2(out) + if self.bn2 is not None: + out = self.bn2(out) + if self.drop_block2 is not None: + out = self.drop_block2(out) + out = self.act2(out) + + if self.avd_last is not None: + out = self.avd_last(out) + + out = self.conv3(out) + out = self.bn3(out) + if self.drop_block3 is not None: + out = self.drop_block3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act3(out) + return out + + +@register_model +def resnest14d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-14d model. Weights ported from GluonCV. + """ + default_cfg = default_cfgs['resnest14d'] + model = ResNet( + ResNestBottleneck, [1, 1, 1, 1], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-26d model. Weights ported from GluonCV. + """ + default_cfg = default_cfgs['resnest26d'] + model = ResNet( + ResNestBottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample. + """ + default_cfg = default_cfgs['resnest50d'] + model = ResNet( + ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest101e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + default_cfg = default_cfgs['resnest101e'] + model = ResNet( + ResNestBottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest200e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + default_cfg = default_cfgs['resnest200e'] + model = ResNet( + ResNestBottleneck, [3, 24, 36, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest269e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + default_cfg = default_cfgs['resnest269e'] + model = ResNet( + ResNestBottleneck, [3, 30, 48, 8], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest50d_4s2x40d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md + """ + default_cfg = default_cfgs['resnest50d_4s2x40d'] + model = ResNet( + ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2, + block_args=dict(radix=4, avd=True, avd_first=True), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest50d_1s4x24d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md + """ + default_cfg = default_cfgs['resnest50d_1s4x24d'] + model = ResNet( + ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4, + block_args=dict(radix=1, avd=True, avd_first=True), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model