From c8b3d6b81a478ec72b8d5f75015b3859af926df1 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 24 Jan 2020 19:45:05 -0800
Subject: [PATCH] Initial impl of Selective Kernel Networks. Very much a WIP.

---
 timm/models/resnet.py | 146 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 146 insertions(+)

diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 422eb0cb..b3adf6dc 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -6,6 +6,7 @@ additional dropout and dynamic global avg/max pool.
 ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
 """
 import math
+from collections import OrderedDict
 
 import torch
 import torch.nn as nn
@@ -100,6 +101,7 @@ default_cfgs = {
     'seresnext26tn_32x4d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
         interpolation='bicubic'),
+    'skresnet26d': _cfg()
 }
 
 
@@ -232,6 +234,137 @@ class Bottleneck(nn.Module):
         return out
 
 
+class SelectiveKernelAttn(nn.Module):
+    def __init__(self, channels, num_paths=2, num_attn_feat=32,
+                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+        super(SelectiveKernelAttn, self).__init__()
+        self.num_paths = num_paths
+        self.pool = nn.AdaptiveAvgPool2d(1)
+        self.fc_reduce = nn.Conv2d(channels, num_attn_feat, kernel_size=1, bias=False)
+        self.bn = norm_layer(num_attn_feat)
+        self.act = act_layer(inplace=True)
+        self.fc_select = nn.Conv2d(num_attn_feat, channels * num_paths, kernel_size=1)
+
+    def forward(self, x):
+        assert x.shape[1] == self.num_paths
+        x = torch.sum(x, dim=1)
+        #print('attn sum', x.shape)
+        x = self.pool(x)
+        #print('attn pool', x.shape)
+        x = self.fc_reduce(x)
+        x = self.bn(x)
+        x = self.act(x)
+        x = self.fc_select(x)
+        #print('attn sel', x.shape)
+        x = x.view((x.shape[0], self.num_paths, x.shape[1]//self.num_paths) + x.shape[-2:])
+        #print('attn spl', x.shape)
+        x = torch.softmax(x, dim=1)
+        return x
+
+
+class SelectiveKernelConv(nn.Module):
+
+    def __init__(self, in_channels, out_channels, kernel_size=[3, 5], attn_reduction=16,
+                 min_attn_feat=32, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True,
+                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+        super(SelectiveKernelConv, self).__init__()
+        if not isinstance(kernel_size, list):
+            assert kernel_size >= 3 and kernel_size % 2
+            kernel_size = [kernel_size] * 2
+        else:
+            # FIXME assert kernel sizes >=3 and odd
+            pass
+        if keep_3x3:
+            dilation = [dilation * (k - 1) // 2 for k in kernel_size]
+            kernel_size = [3] * len(kernel_size)
+        else:
+            dilation = [dilation] * len(kernel_size)
+        groups = min(out_channels // len(kernel_size), groups)
+
+        self.conv_paths = nn.ModuleList()
+        for k, d in zip(kernel_size, dilation):
+            p = _get_padding(k, stride, d)
+            self.conv_paths.append(nn.Sequential(OrderedDict([
+                ('conv', nn.Conv2d(
+                    in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)),
+                ('bn', norm_layer(out_channels)),
+                ('act', act_layer(inplace=True))
+            ])))
+
+        if use_attn:
+            num_attn_feat = max(int(out_channels / attn_reduction), min_attn_feat)
+            self.attn = SelectiveKernelAttn(out_channels, len(kernel_size), num_attn_feat)
+        else:
+            self.attn = None
+
+    def forward(self, x):
+        x_paths = []
+        for conv in self.conv_paths:
+            xk = conv(x)
+            x_paths.append(xk)
+        if self.attn is not None:
+            x_paths = torch.stack(x_paths, dim=1)
+            # print('paths', x_paths.shape)
+            x_attn = self.attn(x_paths)
+            #print('attn', x_attn.shape)
+            x = x_paths * x_attn
+            #print('amul', x.shape)
+            x = torch.sum(x, dim=1)
+            #print('asum', x.shape)
+        else:
+            x = torch.cat(x_paths, dim=1)
+        return x
+
+
+class SelectiveKernelBottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None,
+                 cardinality=1, base_width=64, use_se=False,
+                 reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+        super(SelectiveKernelBottleneck, self).__init__()
+
+        width = int(math.floor(planes * (base_width / 64)) * cardinality)
+        first_planes = width // reduce_first
+        outplanes = planes * self.expansion
+
+        self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
+        self.bn1 = norm_layer(first_planes)
+        self.act1 = act_layer(inplace=True)
+        self.conv2 = SelectiveKernelConv(
+            first_planes, width, stride=stride, dilation=dilation, groups=cardinality)
+        self.bn2 = norm_layer(width)
+        self.act2 = act_layer(inplace=True)
+        self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
+        self.bn3 = norm_layer(outplanes)
+        self.act3 = act_layer(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+        self.dilation = dilation
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.act1(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.act2(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.act3(out)
+
+        return out
+
+
 class ResNet(nn.Module):
     """ResNet / ResNeXt / SE-ResNeXt / SE-Net
 
@@ -472,6 +605,19 @@ def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return model
 
 
+@register_model
+def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a Selective Kernel ResNet-26-D model.
+    """
+    default_cfg = default_cfgs['skresnet26d']
+    model = ResNet(
+        SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
 @register_model
 def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Constructs a ResNet-26 v1d model.
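
Below is a minimal smoke-test sketch (not part of the commit) for the modules this patch adds, assuming the diff above has been applied to timm/models/resnet.py. The SelectiveKernelConv and skresnet26d names come from the patch itself; the dummy tensor sizes and the expected output shapes in the comments are illustrative assumptions, and no pretrained weights exist for skresnet26d at this stage.

import torch
from timm.models.resnet import SelectiveKernelConv, skresnet26d

# Default two-path SK conv: with keep_3x3=True both paths use 3x3 kernels,
# the second dilated by 2 to emulate the 5x5 branch of the SK paper.
sk = SelectiveKernelConv(in_channels=64, out_channels=128)
x = torch.randn(2, 64, 56, 56)
y = sk(x)            # paths are stacked, attention-weighted, then summed
print(y.shape)       # expected: torch.Size([2, 128, 56, 56])

# New registered entrypoint; random init only, no pretrained weights yet.
model = skresnet26d(pretrained=False)
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000])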