Move Selective Kernel blocks/convs to their own sknet.py file

pull/87/head
Ross Wightman 4 years ago
parent a93bae6dc5
commit 58e28dc7e7

@ -16,6 +16,7 @@ from .gluon_xception import *
from .res2net import *
from .dla import *
from .hrnet import *
from .sknet import *
from .registry import *
from .factory import create_model

@ -6,7 +6,6 @@ additional dropout and dynamic global avg/max pool.
ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
"""
import math
from collections import OrderedDict
import torch
import torch.nn as nn
@ -101,11 +100,10 @@ default_cfgs = {
'seresnext26tn_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
interpolation='bicubic'),
'skresnet26d': _cfg()
}
def _get_padding(kernel_size, stride, dilation=1):
def get_padding(kernel_size, stride, dilation=1):
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
@ -234,184 +232,6 @@ class Bottleneck(nn.Module):
return out
class SelectiveKernelAttn(nn.Module):
def __init__(self, channels, num_paths=2, num_attn_feat=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelAttn, self).__init__()
self.num_paths = num_paths
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc_reduce = nn.Conv2d(channels, num_attn_feat, kernel_size=1, bias=False)
self.bn = norm_layer(num_attn_feat)
self.act = act_layer(inplace=True)
self.fc_select = nn.Conv2d(num_attn_feat, channels * num_paths, kernel_size=1, bias=False)
def forward(self, x):
assert x.shape[1] == self.num_paths
x = torch.sum(x, dim=1)
#print('attn sum', x.shape)
x = self.pool(x)
#print('attn pool', x.shape)
x = self.fc_reduce(x)
x = self.bn(x)
x = self.act(x)
x = self.fc_select(x)
#print('attn sel', x.shape)
x = x.view((x.shape[0], self.num_paths, x.shape[1]//self.num_paths) + x.shape[-2:])
#print('attn spl', x.shape)
x = torch.softmax(x, dim=1)
return x
class SelectiveKernelConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=[3, 5], attn_reduction=16,
min_attn_feat=16, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelConv, self).__init__()
if not isinstance(kernel_size, list):
assert kernel_size >= 3 and kernel_size % 2
kernel_size = [kernel_size] * 2
else:
# FIXME assert kernel sizes >=3 and odd
pass
if keep_3x3:
dilation = [dilation * (k - 1) // 2 for k in kernel_size]
kernel_size = [3] * len(kernel_size)
else:
dilation = [dilation] * len(kernel_size)
groups = min(out_channels // len(kernel_size), groups)
self.conv_paths = nn.ModuleList()
for k, d in zip(kernel_size, dilation):
p = _get_padding(k, stride, d)
self.conv_paths.append(nn.Sequential(OrderedDict([
('conv', nn.Conv2d(
in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)),
('bn', norm_layer(out_channels)),
('act', act_layer(inplace=True))
])))
if use_attn:
num_attn_feat = max(int(out_channels / attn_reduction), min_attn_feat)
self.attn = SelectiveKernelAttn(out_channels, len(kernel_size), num_attn_feat)
else:
self.attn = None
def forward(self, x):
x_paths = []
for conv in self.conv_paths:
xk = conv(x)
x_paths.append(xk)
if self.attn is not None:
x_paths = torch.stack(x_paths, dim=1)
# print('paths', x_paths.shape)
x_attn = self.attn(x_paths)
#print('attn', x_attn.shape)
x = x_paths * x_attn
#print('amul', x.shape)
x = torch.sum(x, dim=1)
#print('asum', x.shape)
else:
x = torch.cat(x_paths, dim=1)
return x
class SelectiveKernelBasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, use_se=False,
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelBasicBlock, self).__init__()
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
assert base_width == 64, 'BasicBlock doest not support changing base width'
first_planes = planes // reduce_first
outplanes = planes * self.expansion
self.conv1 = nn.Conv2d(
inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
dilation=dilation, bias=False)
self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True)
self.conv2 = SelectiveKernelConv(first_planes, outplanes, dilation=previous_dilation)
self.bn2 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.act2 = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.act1(out)
out = self.conv2(out)
out = self.bn2(out)
if self.se is not None:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.act2(out)
return out
class SelectiveKernelBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, use_se=False,
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelBottleneck, self).__init__()
width = int(math.floor(planes * (base_width / 64)) * cardinality)
first_planes = width // reduce_first
outplanes = planes * self.expansion
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True)
self.conv2 = SelectiveKernelConv(
first_planes, width, stride=stride, dilation=dilation, groups=cardinality)
self.bn2 = norm_layer(width)
self.act2 = act_layer(inplace=True)
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.act3 = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.act1(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.act2(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.act3(out)
return out
class ResNet(nn.Module):
"""ResNet / ResNeXt / SE-ResNeXt / SE-Net
@ -560,7 +380,7 @@ class ResNet(nn.Module):
downsample = None
down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
if stride != 1 or self.inplanes != planes * block.expansion:
downsample_padding = _get_padding(down_kernel_size, stride)
downsample_padding = get_padding(down_kernel_size, stride)
downsample_layers = []
conv_stride = stride
if avg_down:
@ -628,18 +448,6 @@ def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model
@register_model
def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-18 model.
"""
default_cfg = default_cfgs['resnet18']
model = ResNet(SelectiveKernelBasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-34 model.
@ -664,19 +472,6 @@ def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
return model
@register_model
def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-26 model.
"""
default_cfg = default_cfgs['skresnet26d']
model = ResNet(
SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-26 v1d model.

@ -0,0 +1,294 @@
import math
from collections import OrderedDict
import torch
from torch import nn as nn
from timm.models.registry import register_model
from timm.models.helpers import load_pretrained
from timm.models.resnet import ResNet, get_padding, SEModule
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv1', 'classifier': 'fc',
**kwargs
}
default_cfgs = {
'skresnet18': _cfg(url=''),
'skresnet26d': _cfg()
}
class SelectiveKernelAttn(nn.Module):
def __init__(self, channels, num_paths=2, attn_channels=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelAttn, self).__init__()
self.num_paths = num_paths
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
self.bn = norm_layer(attn_channels)
self.act = act_layer(inplace=True)
self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
def forward(self, x):
assert x.shape[1] == self.num_paths
x = torch.sum(x, dim=1)
#print('attn sum', x.shape)
x = self.pool(x)
#print('attn pool', x.shape)
x = self.fc_reduce(x)
x = self.bn(x)
x = self.act(x)
x = self.fc_select(x)
#print('attn sel', x.shape)
B, C, H, W = x.shape
x = x.view(B, self.num_paths, C // self.num_paths, H, W)
#print('attn spl', x.shape)
x = torch.softmax(x, dim=1)
return x
def _kernel_valid(k):
if isinstance(k, (list, tuple)):
for ki in k:
return _kernel_valid(ki)
assert k >= 3 and k % 2
class SelectiveKernelConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=[3, 5], stride=1, dilation=1, groups=1,
attn_reduction=16, min_attn_channels=32, keep_3x3=True, use_attn=True,
split_input=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelConv, self).__init__()
_kernel_valid(kernel_size)
if not isinstance(kernel_size, list):
kernel_size = [kernel_size] * 2
if keep_3x3:
dilation = [dilation * (k - 1) // 2 for k in kernel_size]
kernel_size = [3] * len(kernel_size)
else:
dilation = [dilation] * len(kernel_size)
num_paths = len(kernel_size)
self.num_paths = num_paths
self.split_input = split_input
self.in_channels = in_channels
self.out_channels = out_channels
if split_input:
assert in_channels % num_paths == 0 and out_channels % num_paths == 0
in_channels = in_channels // num_paths
out_channels = out_channels // num_paths
groups = min(out_channels, groups)
self.paths = nn.ModuleList()
for k, d in zip(kernel_size, dilation):
p = get_padding(k, stride, d)
self.paths.append(nn.Sequential(OrderedDict([
('conv', nn.Conv2d(
in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)),
('bn', norm_layer(out_channels)),
('act', act_layer(inplace=True))
])))
if use_attn:
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels)
else:
self.attn = None
def forward(self, x):
if self.split_input:
x_split = torch.split(x, self.out_channels // self.num_paths, 1)
x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
else:
x_paths = [op(x) for op in self.paths]
if self.attn is not None:
x = torch.stack(x_paths, dim=1)
# print('paths', x_paths.shape)
x_attn = self.attn(x)
#print('attn', x_attn.shape)
x = x * x_attn
#print('amul', x.shape)
if self.split_input:
B, N, C, H, W = x.shape
x = x.reshape(B, N * C, H, W)
else:
x = torch.sum(x, dim=1)
#print('aout', x.shape)
return x
class SelectiveKernelBasic(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelBasic, self).__init__()
sk_kwargs = sk_kwargs or {}
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
assert base_width == 64, 'BasicBlock doest not support changing base width'
first_planes = planes // reduce_first
outplanes = planes * self.expansion
_selective_first = True # FIXME temporary, for experiments
if _selective_first:
self.conv1 = SelectiveKernelConv(
inplanes, first_planes, stride=stride, dilation=dilation, **sk_kwargs)
else:
self.conv1 = nn.Conv2d(
inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
dilation=dilation, bias=False)
self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True)
if _selective_first:
self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=previous_dilation,
dilation=previous_dilation, bias=False)
else:
self.conv2 = SelectiveKernelConv(
first_planes, outplanes, dilation=previous_dilation, **sk_kwargs)
self.bn2 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.act2 = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.act1(out)
out = self.conv2(out)
out = self.bn2(out)
if self.se is not None:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.act2(out)
return out
class SelectiveKernelBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SelectiveKernelBottleneck, self).__init__()
sk_kwargs = sk_kwargs or {}
width = int(math.floor(planes * (base_width / 64)) * cardinality)
first_planes = width // reduce_first
outplanes = planes * self.expansion
self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
self.bn1 = norm_layer(first_planes)
self.act1 = act_layer(inplace=True)
self.conv2 = SelectiveKernelConv(
first_planes, width, stride=stride, dilation=dilation, groups=cardinality, **sk_kwargs)
self.bn2 = norm_layer(width)
self.act2 = act_layer(inplace=True)
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.act3 = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.act1(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.act2(out)
out = self.conv3(out)
out = self.bn3(out)
if self.se is not None:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.act3(out)
return out
@register_model
def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-26 model.
"""
default_cfg = default_cfgs['skresnet26d']
sk_kwargs = dict(
keep_3x3=False,
)
model = ResNet(
SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
**kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-18 model.
"""
default_cfg = default_cfgs['skresnet18']
sk_kwargs = dict(
min_attn_channels=16,
)
model = ResNet(
SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a ResNet-18 model.
"""
default_cfg = default_cfgs['skresnet18']
sk_kwargs = dict(
min_attn_channels=16,
split_input=True
)
model = ResNet(
SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
Loading…
Cancel
Save