From f4cdc2ac319a3d76db2986e7179d99715db14e8c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 11 May 2020 23:27:09 -0700 Subject: [PATCH] Add ResNeSt models --- README.md | 1 + timm/models/__init__.py | 1 + timm/models/layers/split_attn.py | 83 ++++++++++++ timm/models/resnest.py | 214 +++++++++++++++++++++++++++++++ 4 files changed, 299 insertions(+) create mode 100644 timm/models/layers/split_attn.py create mode 100644 timm/models/resnest.py diff --git a/README.md b/README.md index 70e2a701..ac6b57ce 100644 --- a/README.md +++ b/README.md @@ -130,6 +130,7 @@ Included models: * Instagram trained / ImageNet tuned ResNeXt101-32x8d to 32x48d from from [facebookresearch](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/) * Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169) * Selective Kernel (SK) Nets (https://arxiv.org/abs/1903.06586) + * ResNeSt (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955) * DLA * Original (https://github.com/ucbdrive/dla, https://arxiv.org/abs/1707.06484) * Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index b073eb3a..d421ad45 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -18,6 +18,7 @@ from .dla import * from .hrnet import * from .sknet import * from .tresnet import * +from .resnest import * from .registry import * from .factory import create_model diff --git a/timm/models/layers/split_attn.py b/timm/models/layers/split_attn.py new file mode 100644 index 00000000..f91892de --- /dev/null +++ b/timm/models/layers/split_attn.py @@ -0,0 +1,83 @@ +""" Split Attention Conv2d (for ResNeSt Models) + +Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt + +Modified for torchscript compat, performance, and consistency with timm by Ross Wightman +""" +import torch +import torch.nn.functional as F +from torch import nn + + +class RadixSoftmax(nn.Module): + def __init__(self, radix, cardinality): + super(RadixSoftmax, self).__init__() + self.radix = radix + self.cardinality = cardinality + + def forward(self, x): + batch = x.size(0) + if self.radix > 1: + x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) + x = F.softmax(x, dim=1) + x = x.reshape(batch, -1) + else: + x = torch.sigmoid(x) + return x + + +class SplitAttnConv2d(nn.Module): + """Split-Attention Conv2d + """ + def __init__(self, in_channels, channels, kernel_size, stride=1, padding=0, + dilation=1, groups=1, bias=False, radix=2, reduction_factor=4, + act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): + super(SplitAttnConv2d, self).__init__() + self.radix = radix + self.cardinality = groups + self.channels = channels + mid_chs = channels * radix + attn_chs = max(in_channels * radix // reduction_factor, 32) + self.conv = nn.Conv2d( + in_channels, mid_chs, kernel_size, stride, padding, dilation, + groups=groups * radix, bias=bias, **kwargs) + self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None + self.act0 = act_layer(inplace=True) + self.fc1 = nn.Conv2d(channels, attn_chs, 1, groups=self.cardinality) + self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None + self.act1 = act_layer(inplace=True) + self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=self.cardinality) + self.drop_block = drop_block + self.rsoftmax = RadixSoftmax(radix, groups) + + def forward(self, x): + x = self.conv(x) + if self.bn0 is not None: + x = self.bn0(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act0(x) + + B, RC, H, W = x.shape + if self.radix > 1: + x = x.reshape((B, self.radix, RC // self.radix, H, W)) + x_gap = torch.sum(x, dim=1) + else: + x_gap = x + x_gap = F.adaptive_avg_pool2d(x_gap, 1) + x_gap = self.fc1(x_gap) + + if self.bn1 is not None: + x_gap = self.bn1(x_gap) + x_gap = self.act1(x_gap) + + x_attn = self.fc2(x_gap) + x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) + + if self.radix > 1: + out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) + else: + out = x * x_attn + return out.contiguous() diff --git a/timm/models/resnest.py b/timm/models/resnest.py new file mode 100644 index 00000000..e4f0157b --- /dev/null +++ b/timm/models/resnest.py @@ -0,0 +1,214 @@ +""" ResNeSt Models + +Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 + +Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt + +Modified for torchscript compat, and consistency with timm by Ross Wightman +""" +import math +import torch +import torch.nn.functional as F +from torch import nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import DropBlock2d +from .helpers import load_pretrained +from .layers import SelectiveKernelConv, ConvBnAct, create_attn +from .layers.split_attn import SplitAttnConv2d +from .registry import register_model +from .resnet import ResNet + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + +default_cfgs = { + 'resnest26d': _cfg( + url=''), + 'resnest50d': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'), + 'resnest101e': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth', input_size=(3, 256, 256)), + 'resnest200e': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)), + 'resnest269e': _cfg( + url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)), +} + + +class ResNestBottleneck(nn.Module): + """ResNet Bottleneck + """ + # pylint: disable=unused-argument + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): + super(ResNestBottleneck, self).__init__() + assert reduce_first == 1 # not supported + assert attn_layer is None # not supported + assert aa_layer is None # TODO not yet supported + assert drop_path is None # TODO not yet supported + + group_width = int(planes * (base_width / 64.)) * cardinality + first_dilation = first_dilation or dilation + if avd and (stride > 1 or is_first): + avd_stride = stride + stride = 1 + else: + avd_stride = 0 + self.radix = radix + + self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) + self.bn1 = norm_layer(group_width) + self.drop_block1 = drop_block if drop_block is not None else None + self.act1 = act_layer(inplace=True) + self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None + + if self.radix >= 1: + self.conv2 = SplitAttnConv2d( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, norm_layer=norm_layer, drop_block=drop_block) + self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness + self.drop_block2 = None + self.act2 = None + else: + self.conv2 = nn.Conv2d( + group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False) + self.bn2 = norm_layer(group_width) + self.drop_block2 = drop_block if drop_block is not None else None + self.act2 = act_layer(inplace=True) + self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None + + self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False) + self.bn3 = norm_layer(planes*4) + self.drop_block3 = drop_block if drop_block is not None else None + self.act3 = act_layer(inplace=True) + self.downsample = downsample + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + if self.drop_block1 is not None: + out = self.drop_block1(out) + out = self.act1(out) + + if self.avd_first is not None: + out = self.avd_first(out) + + out = self.conv2(out) + if self.bn2 is not None: + out = self.bn2(out) + if self.drop_block2 is not None: + out = self.drop_block2(out) + out = self.act2(out) + + if self.avd_last is not None: + out = self.avd_last(out) + + out = self.conv3(out) + out = self.bn3(out) + if self.drop_block3 is not None: + out = self.drop_block3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act3(out) + return out + + +@register_model +def resnest26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-26d model. + """ + default_cfg = default_cfgs['resnest26d'] + model = ResNet( + ResNestBottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample. + """ + default_cfg = default_cfgs['resnest50d'] + model = ResNet( + ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest101e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + default_cfg = default_cfgs['resnest101e'] + model = ResNet( + ResNestBottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest200e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + default_cfg = default_cfgs['resnest200e'] + model = ResNet( + ResNestBottleneck, [3, 24, 36, 3], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def resnest269e(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955 + Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample. + """ + default_cfg = default_cfgs['resnest269e'] + model = ResNet( + ResNestBottleneck, [3, 30, 48, 8], num_classes=num_classes, in_chans=in_chans, + stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, + block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model