""" Selective Kernel Networks (ResNet base) Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586) This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268) and a streamlined impl at https://github.com/clovaai/assembled-cnn but I ended up building something closer to the original paper with some modifications of my own to better balance param count vs accuracy. Hacked together by Ross Wightman """ import math from torch import nn as nn from .registry import register_model from .helpers import load_pretrained from .layers import SelectiveKernelConv, ConvBnAct, create_attn from .resnet import ResNet from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'conv1', 'classifier': 'fc', **kwargs } default_cfgs = { 'skresnet18': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'), 'skresnet34': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'), 'skresnet50': _cfg(), 'skresnet50d': _cfg(), 'skresnext50_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth'), } class SelectiveKernelBasic(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): super(SelectiveKernelBasic, self).__init__() sk_kwargs = sk_kwargs or {} conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first outplanes = planes * self.expansion first_dilation = first_dilation or dilation self.conv1 = SelectiveKernelConv( inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) conv_kwargs['act_layer'] = None self.conv2 = ConvBnAct( first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs) self.se = create_attn(attn_layer, outplanes) self.act = act_layer(inplace=True) self.downsample = downsample self.stride = stride self.dilation = dilation self.drop_block = drop_block self.drop_path = drop_path def zero_init_last_bn(self): nn.init.zeros_(self.conv2.bn.weight) def forward(self, x): residual = x x = self.conv1(x) x = self.conv2(x) if self.se is not None: x = self.se(x) if self.drop_path is not None: x = self.drop_path(x) if self.downsample is not None: residual = self.downsample(residual) x += residual x = self.act(x) return x class SelectiveKernelBottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): super(SelectiveKernelBottleneck, self).__init__() sk_kwargs = sk_kwargs or {} conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, 
norm_layer=norm_layer, aa_layer=aa_layer) width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion first_dilation = first_dilation or dilation self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) self.conv2 = SelectiveKernelConv( first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, **conv_kwargs, **sk_kwargs) conv_kwargs['act_layer'] = None self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs) self.se = create_attn(attn_layer, outplanes) self.act = act_layer(inplace=True) self.downsample = downsample self.stride = stride self.dilation = dilation self.drop_block = drop_block self.drop_path = drop_path def zero_init_last_bn(self): nn.init.zeros_(self.conv3.bn.weight) def forward(self, x): residual = x x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) if self.se is not None: x = self.se(x) if self.drop_path is not None: x = self.drop_path(x) if self.downsample is not None: residual = self.downsample(residual) x += residual x = self.act(x) return x @register_model def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a Selective Kernel ResNet-18 model. Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ default_cfg = default_cfgs['skresnet18'] sk_kwargs = dict( min_attn_channels=16, attn_reduction=8, split_input=True ) model = ResNet( SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model @register_model def skresnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a Selective Kernel ResNet-34 model. Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ default_cfg = default_cfgs['skresnet34'] sk_kwargs = dict( min_attn_channels=16, attn_reduction=8, split_input=True ) model = ResNet( SelectiveKernelBasic, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model @register_model def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a Select Kernel ResNet-50 model. Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ sk_kwargs = dict( split_input=True, ) default_cfg = default_cfgs['skresnet50'] model = ResNet( SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model @register_model def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a Select Kernel ResNet-50-D model. Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." 
this variation splits the input channels to the selective convolutions to keep param count down. """ sk_kwargs = dict( split_input=True, ) default_cfg = default_cfgs['skresnet50d'] model = ResNet( SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model @register_model def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to the SKNet-50 model in the Select Kernel Paper """ default_cfg = default_cfgs['skresnext50_32x4d'] model = ResNet( SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model
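

# ---------------------------------------------------------------------------
# A minimal smoke-test sketch, not part of the original module. It assumes the
# relative imports above resolve, i.e. the file is executed in its package
# context (e.g. `python -m timm.models.sknet`). It builds two of the registered
# models directly via their constructor functions and checks output shapes;
# the (3, 224, 224) input size follows default_cfgs above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch

    for fn in (skresnet18, skresnext50_32x4d):
        model = fn(pretrained=False, num_classes=1000)
        model.eval()
        with torch.no_grad():
            out = model(torch.randn(2, 3, 224, 224))
        assert out.shape == (2, 1000), out.shape
        n_params = sum(p.numel() for p in model.parameters())
        print(f'{fn.__name__}: output {tuple(out.shape)}, params {n_params:,}')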