|
|
|
@ -6,6 +6,7 @@ additional dropout and dynamic global avg/max pool.
|
|
|
|
|
ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
|
|
|
|
|
"""
|
|
|
|
|
import math
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
@ -100,6 +101,7 @@ default_cfgs = {
|
|
|
|
|
'seresnext26tn_32x4d': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
|
|
|
|
|
interpolation='bicubic'),
|
|
|
|
|
'skresnet26d': _cfg()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -232,6 +234,137 @@ class Bottleneck(nn.Module):
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SelectiveKernelAttn(nn.Module):
|
|
|
|
|
def __init__(self, channels, num_paths=2, num_attn_feat=32,
|
|
|
|
|
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
|
|
|
|
super(SelectiveKernelAttn, self).__init__()
|
|
|
|
|
self.num_paths = num_paths
|
|
|
|
|
self.pool = nn.AdaptiveAvgPool2d(1)
|
|
|
|
|
self.fc_reduce = nn.Conv2d(channels, num_attn_feat, kernel_size=1, bias=False)
|
|
|
|
|
self.bn = norm_layer(num_attn_feat)
|
|
|
|
|
self.act = act_layer(inplace=True)
|
|
|
|
|
self.fc_select = nn.Conv2d(num_attn_feat, channels * num_paths, kernel_size=1)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
assert x.shape[1] == self.num_paths
|
|
|
|
|
x = torch.sum(x, dim=1)
|
|
|
|
|
#print('attn sum', x.shape)
|
|
|
|
|
x = self.pool(x)
|
|
|
|
|
#print('attn pool', x.shape)
|
|
|
|
|
x = self.fc_reduce(x)
|
|
|
|
|
x = self.bn(x)
|
|
|
|
|
x = self.act(x)
|
|
|
|
|
x = self.fc_select(x)
|
|
|
|
|
#print('attn sel', x.shape)
|
|
|
|
|
x = x.view((x.shape[0], self.num_paths, x.shape[1]//self.num_paths) + x.shape[-2:])
|
|
|
|
|
#print('attn spl', x.shape)
|
|
|
|
|
x = torch.softmax(x, dim=1)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SelectiveKernelConv(nn.Module):
|
|
|
|
|
|
|
|
|
|
def __init__(self, in_channels, out_channels, kernel_size=[3, 5], attn_reduction=16,
|
|
|
|
|
min_attn_feat=32, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True,
|
|
|
|
|
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
|
|
|
|
super(SelectiveKernelConv, self).__init__()
|
|
|
|
|
if not isinstance(kernel_size, list):
|
|
|
|
|
assert kernel_size >= 3 and kernel_size % 2
|
|
|
|
|
kernel_size = [kernel_size] * 2
|
|
|
|
|
else:
|
|
|
|
|
# FIXME assert kernel sizes >=3 and odd
|
|
|
|
|
pass
|
|
|
|
|
if keep_3x3:
|
|
|
|
|
dilation = [dilation * (k - 1) // 2 for k in kernel_size]
|
|
|
|
|
kernel_size = [3] * len(kernel_size)
|
|
|
|
|
else:
|
|
|
|
|
dilation = [dilation] * len(kernel_size)
|
|
|
|
|
groups = min(out_channels // len(kernel_size), groups)
|
|
|
|
|
|
|
|
|
|
self.conv_paths = nn.ModuleList()
|
|
|
|
|
for k, d in zip(kernel_size, dilation):
|
|
|
|
|
p = _get_padding(k, stride, d)
|
|
|
|
|
self.conv_paths.append(nn.Sequential(OrderedDict([
|
|
|
|
|
('conv', nn.Conv2d(
|
|
|
|
|
in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)),
|
|
|
|
|
('bn', norm_layer(out_channels)),
|
|
|
|
|
('act', act_layer(inplace=True))
|
|
|
|
|
])))
|
|
|
|
|
|
|
|
|
|
if use_attn:
|
|
|
|
|
num_attn_feat = max(int(out_channels / attn_reduction), min_attn_feat)
|
|
|
|
|
self.attn = SelectiveKernelAttn(out_channels, len(kernel_size), num_attn_feat)
|
|
|
|
|
else:
|
|
|
|
|
self.attn = None
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x_paths = []
|
|
|
|
|
for conv in self.conv_paths:
|
|
|
|
|
xk = conv(x)
|
|
|
|
|
x_paths.append(xk)
|
|
|
|
|
if self.attn is not None:
|
|
|
|
|
x_paths = torch.stack(x_paths, dim=1)
|
|
|
|
|
# print('paths', x_paths.shape)
|
|
|
|
|
x_attn = self.attn(x_paths)
|
|
|
|
|
#print('attn', x_attn.shape)
|
|
|
|
|
x = x_paths * x_attn
|
|
|
|
|
#print('amul', x.shape)
|
|
|
|
|
x = torch.sum(x, dim=1)
|
|
|
|
|
#print('asum', x.shape)
|
|
|
|
|
else:
|
|
|
|
|
x = torch.cat(x_paths, dim=1)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SelectiveKernelBottleneck(nn.Module):
|
|
|
|
|
expansion = 4
|
|
|
|
|
|
|
|
|
|
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
|
|
|
|
cardinality=1, base_width=64, use_se=False,
|
|
|
|
|
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
|
|
|
|
|
super(SelectiveKernelBottleneck, self).__init__()
|
|
|
|
|
|
|
|
|
|
width = int(math.floor(planes * (base_width / 64)) * cardinality)
|
|
|
|
|
first_planes = width // reduce_first
|
|
|
|
|
outplanes = planes * self.expansion
|
|
|
|
|
|
|
|
|
|
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
|
|
|
|
|
self.bn1 = norm_layer(first_planes)
|
|
|
|
|
self.act1 = act_layer(inplace=True)
|
|
|
|
|
self.conv2 = SelectiveKernelConv(
|
|
|
|
|
first_planes, width, stride=stride, dilation=dilation, groups=cardinality)
|
|
|
|
|
self.bn2 = norm_layer(width)
|
|
|
|
|
self.act2 = act_layer(inplace=True)
|
|
|
|
|
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
|
|
|
|
self.bn3 = norm_layer(outplanes)
|
|
|
|
|
self.act3 = act_layer(inplace=True)
|
|
|
|
|
self.downsample = downsample
|
|
|
|
|
self.stride = stride
|
|
|
|
|
self.dilation = dilation
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
residual = x
|
|
|
|
|
|
|
|
|
|
out = self.conv1(x)
|
|
|
|
|
out = self.bn1(out)
|
|
|
|
|
out = self.act1(out)
|
|
|
|
|
|
|
|
|
|
out = self.conv2(out)
|
|
|
|
|
out = self.bn2(out)
|
|
|
|
|
out = self.act2(out)
|
|
|
|
|
|
|
|
|
|
out = self.conv3(out)
|
|
|
|
|
out = self.bn3(out)
|
|
|
|
|
|
|
|
|
|
if self.downsample is not None:
|
|
|
|
|
residual = self.downsample(x)
|
|
|
|
|
|
|
|
|
|
out += residual
|
|
|
|
|
out = self.act3(out)
|
|
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResNet(nn.Module):
|
|
|
|
|
"""ResNet / ResNeXt / SE-ResNeXt / SE-Net
|
|
|
|
|
|
|
|
|
@ -472,6 +605,19 @@ def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-26 model.
|
|
|
|
|
"""
|
|
|
|
|
default_cfg = default_cfgs['skresnet26d']
|
|
|
|
|
model = ResNet(
|
|
|
|
|
SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
|
|
|
|
|
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
if pretrained:
|
|
|
|
|
load_pretrained(model, default_cfg, num_classes, in_chans)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|
|
|
|
"""Constructs a ResNet-26 v1d model.
|
|
|
|
|