""" Selective Kernel Networks (ResNet base)
Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268)
and a streamlined impl at https://github.com/clovaai/assembled-cnn but I ended up building something closer
to the original paper with some modifications of my own to better balance param count vs accuracy.
Hacked together by Ross Wightman
"""
import math
from torch import nn as nn
from .registry import register_model
from .helpers import load_pretrained
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
from .resnet import ResNet
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv1', 'classifier': 'fc',
**kwargs
}
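# Pretrained weight configs. Entries built from an empty _cfg() (skresnet50,
# skresnet50d) have no released weights; load_pretrained will have no url to fetch.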
default_cfgs = {
'skresnet18': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'),
'skresnet34': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'),
'skresnet50': _cfg(),
'skresnet50d': _cfg(),
'skresnext50_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth'),
}
class SelectiveKernelBasic(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
super(SelectiveKernelBasic, self).__init__()
sk_kwargs = sk_kwargs or {}
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
        assert base_width == 64, 'BasicBlock does not support changing base width'
first_planes = planes // reduce_first
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
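        # First 3x3 stage is the selective kernel conv: parallel branches with
        # different kernel sizes, mixed by a learned channel attention (per the paper).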
self.conv1 = SelectiveKernelConv(
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
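        # Drop the activation on the final conv; the block's ReLU fires after the
        # residual addition in forward().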
conv_kwargs['act_layer'] = None
self.conv2 = ConvBnAct(
first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
self.se = create_attn(attn_layer, outplanes)
self.act = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
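        # Zero gamma on the last BN so the residual branch outputs zero at init
        # and the block starts as an identity mapping.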
nn.init.zeros_(self.conv2.bn.weight)
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.conv2(x)
if self.se is not None:
x = self.se(x)
if self.drop_path is not None:
x = self.drop_path(x)
if self.downsample is not None:
residual = self.downsample(residual)
x += residual
x = self.act(x)
return x
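# Quick shape sanity check for the basic block (a sketch, not part of the
# original file); stride=1 with matching channels needs no downsample:
#
#   block = SelectiveKernelBasic(64, 64)
#   y = block(torch.randn(2, 64, 56, 56))  # -> (2, 64, 56, 56)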
class SelectiveKernelBottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
super(SelectiveKernelBottleneck, self).__init__()
sk_kwargs = sk_kwargs or {}
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
width = int(math.floor(planes * (base_width / 64)) * cardinality)
first_planes = width // reduce_first
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
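        # Middle 3x3 stage is the selective kernel conv, grouped for the
        # ResNeXt-style variants (cardinality > 1).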
self.conv2 = SelectiveKernelConv(
first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
**conv_kwargs, **sk_kwargs)
conv_kwargs['act_layer'] = None
self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs)
self.se = create_attn(attn_layer, outplanes)
self.act = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
nn.init.zeros_(self.conv3.bn.weight)
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
if self.se is not None:
x = self.se(x)
if self.drop_path is not None:
x = self.drop_path(x)
if self.downsample is not None:
residual = self.downsample(residual)
x += residual
x = self.act(x)
return x
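# Bottleneck counterpart of the check above (a sketch); outplanes = planes * 4,
# so a stride-1 residual with no downsample needs inplanes == planes * 4:
#
#   block = SelectiveKernelBottleneck(256, 64)
#   y = block(torch.randn(2, 256, 56, 56))  # -> (2, 256, 56, 56)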
@register_model
def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Selective Kernel ResNet-18 model.
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
default_cfg = default_cfgs['skresnet18']
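    # sk_kwargs below: split_input splits the block's input channels across the
    # SK paths (per the docstring), while attn_reduction / min_attn_channels
    # size the attention bottleneck (assumed semantics of SelectiveKernelConv).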
sk_kwargs = dict(
min_attn_channels=16,
attn_reduction=8,
split_input=True
)
model = ResNet(
SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def skresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Selective Kernel ResNet-34 model.
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
default_cfg = default_cfgs['skresnet34']
sk_kwargs = dict(
min_attn_channels=16,
attn_reduction=8,
split_input=True
)
model = ResNet(
SelectiveKernelBasic, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Select Kernel ResNet-50 model.
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
sk_kwargs = dict(
split_input=True,
)
default_cfg = default_cfgs['skresnet50']
model = ResNet(
SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Select Kernel ResNet-50-D model.
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
sk_kwargs = dict(
split_input=True,
)
default_cfg = default_cfgs['skresnet50d']
model = ResNet(
SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
the SKNet-50 model in the Select Kernel Paper
"""
default_cfg = default_cfgs['skresnext50_32x4d']
model = ResNet(
SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
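

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module): build each
    # registered variant without pretrained weights and run a dummy batch.
    import torch
    for fn in (skresnet18, skresnet34, skresnet50, skresnet50d, skresnext50_32x4d):
        model = fn(pretrained=False)
        out = model(torch.randn(1, 3, 224, 224))
        assert out.shape == (1, 1000), out.shape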