diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 0fa4d210..69e09085 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -16,6 +16,7 @@ from .gluon_xception import * from .res2net import * from .dla import * from .hrnet import * +from .sknet import * from .registry import * from .factory import create_model diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py index acd14fde..feaf653c 100644 --- a/timm/models/conv2d_layers.py +++ b/timm/models/conv2d_layers.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import torch import torch.nn as nn import torch.nn.functional as F @@ -100,14 +102,11 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) -class MixedConv2d(nn.Module): +class MixedConv2d(nn.ModuleDict): """ Mixed Grouped Convolution Based on MDConv and GroupedConv in MixNet impl: https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py - - NOTE: This does not currently work with torch.jit.script """ - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='', dilation=1, depthwise=False, **kwargs): super(MixedConv2d, self).__init__() @@ -131,7 +130,7 @@ class MixedConv2d(nn.Module): def forward(self, x): x_split = torch.split(x, self.splits, 1) - x_out = [c(x) for x, c in zip(x_split, self._modules.values())] + x_out = [c(x_split[i]) for i, c in enumerate(self.values())] x = torch.cat(x_out, 1) return x @@ -240,6 +239,110 @@ class CondConv2d(nn.Module): return out +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = torch.sum(x, dim=1) + x = self.pool(x) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + x = torch.softmax(x, dim=1) + return x + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class ConvBnAct(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(ConvBnAct, self).__init__() + padding = _get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block + self.conv = nn.Conv2d( + in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=False) + self.bn = norm_layer(out_channels) + self.drop_block = drop_block + if act_layer is not None: + self.act = act_layer(inplace=True) + else: + self.act = None + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.drop_block is not None: + x = self.drop_block(x) + if self.act is not None: + x = self.act(x) + return x + + +class SelectiveKernelConv(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, + 
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelConv, self).__init__() + kernel_size = kernel_size or [3, 5] + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + self.num_paths = len(kernel_size) + self.in_channels = in_channels + self.out_channels = out_channels + self.split_input = split_input + if self.split_input: + assert in_channels % self.num_paths == 0 + in_channels = in_channels // self.num_paths + groups = min(out_channels, groups) + + conv_kwargs = dict( + stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) + self.paths = nn.ModuleList([ + ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + for k, d in zip(kernel_size, dilation)]) + + attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) + self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) + self.drop_block = drop_block + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.in_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + x = torch.stack(x_paths, dim=1) + x_attn = self.attn(x) + x = x * x_attn + x = torch.sum(x, dim=1) + return x + + # helper method def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): assert 'groups' not in kwargs # only use 'depthwise' bool arg @@ -256,5 +359,3 @@ def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): else: m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) return m - - diff --git a/timm/models/nn_ops.py b/timm/models/nn_ops.py new file mode 100644 index 00000000..30b98427 --- /dev/null +++ b/timm/models/nn_ops.py @@ -0,0 +1,77 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + + +def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False): + _, _, height, width = x.shape + total_size = width * height + clipped_block_size = min(block_size, min(width, height)) + # seed_drop_rate, the gamma parameter + seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (width - block_size + 1) * + (height - block_size + 1)) + + # Forces the block to be inside the feature map. 
+    h_i, w_i = torch.meshgrid(torch.arange(height).to(x.device), torch.arange(width).to(x.device))
+    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \
+                  ((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2))
+    valid_block = torch.reshape(valid_block, (1, 1, height, width)).float()
+
+    uniform_noise = torch.rand_like(x, dtype=torch.float32)
+    block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float()
+    block_mask = -F.max_pool2d(
+        -block_mask,
+        kernel_size=clipped_block_size,  # block_size,
+        stride=1,
+        padding=clipped_block_size // 2)
+
+    if drop_with_noise:
+        normal_noise = torch.randn_like(x)
+        x = x * block_mask + normal_noise * (1 - block_mask)
+    else:
+        normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7)
+        x = x * block_mask * normalize_scale
+    return x
+
+
+class DropBlock2d(nn.Module):
+    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+    """
+    def __init__(self,
+                 drop_prob=0.1,
+                 block_size=7,
+                 gamma_scale=1.0,
+                 with_noise=False):
+        super(DropBlock2d, self).__init__()
+        self.drop_prob = drop_prob
+        self.gamma_scale = gamma_scale
+        self.block_size = block_size
+        self.with_noise = with_noise
+
+    def forward(self, x):
+        if not self.training or not self.drop_prob:
+            return x
+        return drop_block_2d(x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise)
+
+
+def drop_path(x, drop_prob=0.):
+    """Drop paths (Stochastic Depth) per sample (when applied in residual blocks)."""
+    keep_prob = 1 - drop_prob
+    random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
+    random_tensor.floor_()  # binarize
+    output = x.div(keep_prob) * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        if not self.training or not self.drop_prob:
+            return x
+        return drop_path(x, self.drop_prob)
diff --git a/timm/models/res2net.py b/timm/models/res2net.py
index da20e7a0..bcb7eaaf 100644
--- a/timm/models/res2net.py
+++ b/timm/models/res2net.py
@@ -54,14 +54,15 @@ class Bottle2neck(nn.Module):
 
     def __init__(self, inplanes, planes, stride=1, downsample=None,
                  cardinality=1, base_width=26, scale=4, use_se=False,
-                 act_layer=nn.ReLU, norm_layer=None, dilation=1, previous_dilation=1, **_):
+                 act_layer=nn.ReLU, norm_layer=None, dilation=1, first_dilation=None, **_):
         super(Bottle2neck, self).__init__()
         self.scale = scale
         self.is_first = stride > 1 or downsample is not None
         self.num_scales = max(1, scale - 1)
         width = int(math.floor(planes * (base_width / 64.0))) * cardinality
-        outplanes = planes * self.expansion
         self.width = width
+        outplanes = planes * self.expansion
+        first_dilation = first_dilation or dilation
 
         self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
         self.bn1 = norm_layer(width * scale)
@@ -70,8 +71,8 @@ class Bottle2neck(nn.Module):
         bns = []
         for i in range(self.num_scales):
             convs.append(nn.Conv2d(
-                width, width, kernel_size=3, stride=stride, padding=dilation,
-                dilation=dilation, groups=cardinality, bias=False))
+                width, width, kernel_size=3, stride=stride, padding=first_dilation,
+                dilation=first_dilation, groups=cardinality, bias=False))
            bns.append(norm_layer(width))
         self.convs = nn.ModuleList(convs)
         self.bns = nn.ModuleList(bns)
@@ -86,6 +87,9 @@ class Bottle2neck(nn.Module):
         self.relu = act_layer(inplace=True)
         self.downsample = downsample
 
+    def 
zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + def forward(self, x): residual = x diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 893350ef..a1c593ae 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -15,6 +15,7 @@ from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d from .layers import EcaModule +from .nn_ops import DropBlock2d, DropPath from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -108,7 +109,7 @@ default_cfgs = { } -def _get_padding(kernel_size, stride, dilation=1): +def get_padding(kernel_size, stride, dilation=1): padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 return padding @@ -136,115 +137,135 @@ class BasicBlock(nn.Module): __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, use_eca = False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + drop_block=None, drop_path=None): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, bias=False) + inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( - first_planes, outplanes, kernel_size=3, padding=previous_dilation, - dilation=previous_dilation, bias=False) + first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) self.bn2 = norm_layer(outplanes) self.se = SEModule(outplanes, planes // 4) if use_se else None - self.eca = EcaModule(outplanes) if use_eca else None self.act2 = act_layer(inplace=True) self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn2.weight) def forward(self, x): residual = x - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - out = self.conv2(out) - out = self.bn2(out) + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) if self.se is not None: - out = self.se(out) - if self.eca is not None: - out = self.eca(out) + x = self.se(x) - if self.downsample is not None: - residual = self.downsample(x) + if self.drop_path is not None: + x = self.drop_path(x) - out += residual - out = self.act2(out) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act2(x) - return out + return x class Bottleneck(nn.Module): __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, - 
cardinality=1, base_width=64, use_se=False, use_eca=False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + drop_block=None, drop_path=None): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( first_planes, width, kernel_size=3, stride=stride, - padding=dilation, dilation=dilation, groups=cardinality, bias=False) + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) self.bn2 = norm_layer(width) self.act2 = act_layer(inplace=True) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.bn3 = norm_layer(outplanes) self.se = SEModule(outplanes, planes // 4) if use_se else None - self.eca = EcaModule(outplanes) if use_eca else None - + self.act3 = act_layer(inplace=True) self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) def forward(self, x): residual = x - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) - out = self.conv2(out) - out = self.bn2(out) - out = self.act2(out) + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act2(x) - out = self.conv3(out) - out = self.bn3(out) + x = self.conv3(x) + x = self.bn3(x) + if self.drop_block is not None: + x = self.drop_block(x) if self.se is not None: - out = self.se(out) - if self.eca is not None: - out = self.eca(out) + x = self.se(x) - if self.downsample is not None: - residual = self.downsample(x) + if self.drop_path is not None: + x = self.drop_path(x) - out += residual - out = self.act3(out) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act3(x) - return out + return x class ResNet(nn.Module): @@ -290,8 +311,6 @@ class ResNet(nn.Module): Number of input (color) channels. use_se : bool, default False Enable Squeeze-Excitation module in blocks - use_eca : bool, default False - Enable ECA module in blocks cardinality : int, default 1 Number of convolution groups for 3x3 conv in Bottleneck. 
     base_width : int, default 64
@@ -323,8 +342,8 @@ class ResNet(nn.Module):
     def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, use_eca=False,
                  cardinality=1, base_width=64, stem_width=64, stem_type='',
                  block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
-                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
-                 zero_init_last_bn=True, block_args=None):
+                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
+                 drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
         block_args = block_args or dict()
         self.num_classes = num_classes
         deep_stem = 'deep' in stem_type
@@ -356,6 +375,9 @@ class ResNet(nn.Module):
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
         # Feature Blocks
+        dp = DropPath(drop_path_rate) if drop_path_rate else None
+        db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None
+        db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None
         channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
         if output_stride == 16:
             strides[3] = 1
@@ -367,29 +389,28 @@
             assert output_stride == 32
         llargs = list(zip(channels, layers, strides, dilations))
         lkwargs = dict(
-            use_se=use_se, use_eca=use_eca, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
-            avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
+            use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
+            avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
         self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
         self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
-        self.layer3 = self._make_layer(block, *llargs[2], **lkwargs)
-        self.layer4 = self._make_layer(block, *llargs[3], **lkwargs)
+        self.layer3 = self._make_layer(block, drop_block=db_3, *llargs[2], **lkwargs)
+        self.layer4 = self._make_layer(block, drop_block=db_4, *llargs[3], **lkwargs)
 
         # Head (Pooling and Classifier)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_features = 512 * block.expansion
         self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
 
-        last_bn_name = 'bn3' if 'Bottle' in block.__name__ else 'bn2'
         for n, m in self.named_modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
             elif isinstance(m, nn.BatchNorm2d):
-                if zero_init_last_bn and 'layer' in n and last_bn_name in n:
-                    # Initialize weight/gamma of last BN in each residual block to zero
-                    nn.init.constant_(m.weight, 0.)
-                else:
-                    nn.init.constant_(m.weight, 1.)
+                nn.init.constant_(m.weight, 1.)
                 nn.init.constant_(m.bias, 0.)
+ if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1, use_se=False, use_eca=False,avg_down=False, down_kernel_size=1, **kwargs): @@ -397,7 +418,7 @@ class ResNet(nn.Module): downsample = None down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size if stride != 1 or self.inplanes != planes * block.expansion: - downsample_padding = _get_padding(down_kernel_size, stride) + downsample_padding = get_padding(down_kernel_size, stride) downsample_layers = [] conv_stride = stride if avg_down: @@ -413,13 +434,10 @@ class ResNet(nn.Module): first_dilation = 1 if dilation in (1, 2) else 2 bkwargs = dict( cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, - use_se=use_se, use_eca=use_eca, **kwargs) - layers = [block( - self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)] + dilation=dilation, use_se=use_se, **kwargs) + layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **bkwargs)] self.inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block( - self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bkwargs)) + layers += [block(self.inplanes, planes, **bkwargs) for _ in range(1, blocks)] return nn.Sequential(*layers) @@ -447,8 +465,8 @@ class ResNet(nn.Module): def forward(self, x): x = self.forward_features(x) x = self.global_pool(x).flatten(1) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) x = self.fc(x) return x diff --git a/timm/models/sknet.py b/timm/models/sknet.py new file mode 100644 index 00000000..4b02d501 --- /dev/null +++ b/timm/models/sknet.py @@ -0,0 +1,242 @@ +import math + +from torch import nn as nn + +from timm.models.registry import register_model +from timm.models.helpers import load_pretrained +from timm.models.conv2d_layers import SelectiveKernelConv, ConvBnAct +from timm.models.resnet import ResNet, SEModule +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'skresnet18': _cfg(url=''), + 'skresnet26d': _cfg(), + 'skresnet50': _cfg(), + 'skresnet50d': _cfg(), + 'skresnext50_32x4d': _cfg(), +} + + +class SelectiveKernelBasic(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + use_se=False, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, + drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelBasic, self).__init__() + + sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock doest not support changing base width' + first_planes = planes // reduce_first + out_planes = planes * self.expansion + first_dilation = first_dilation or dilation + + _selective_first = True # FIXME temporary, for 
experiments + if _selective_first: + self.conv1 = SelectiveKernelConv( + inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv2 = ConvBnAct( + first_planes, out_planes, kernel_size=3, dilation=dilation, **conv_kwargs) + else: + self.conv1 = ConvBnAct( + inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs) + conv_kwargs['act_layer'] = None + self.conv2 = SelectiveKernelConv( + first_planes, out_planes, dilation=dilation, **conv_kwargs, **sk_kwargs) + self.se = SEModule(out_planes, planes // 4) if use_se else None + self.act = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv2.bn.weight) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.conv2(x) + if self.se is not None: + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act(x) + return x + + +class SelectiveKernelBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, use_se=False, sk_kwargs=None, + reduce_first=1, dilation=1, first_dilation=None, + drop_block=None, drop_path=None, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelBottleneck, self).__init__() + + sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + out_planes = planes * self.expansion + first_dilation = first_dilation or dilation + + self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) + self.conv2 = SelectiveKernelConv( + first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, + **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv3 = ConvBnAct(width, out_planes, kernel_size=1, **conv_kwargs) + self.se = SEModule(out_planes, planes // 4) if use_se else None + self.act = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.conv3.bn.weight) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + if self.se is not None: + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act(x) + return x + + +@register_model +def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['skresnet18'] + sk_kwargs = dict( + min_attn_channels=16, + ) + model = ResNet( + SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. 
+ """ + default_cfg = default_cfgs['skresnet18'] + sk_kwargs = dict( + min_attn_channels=16, + attn_reduction=8, + split_input=True + ) + model = ResNet( + SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-26 model. + """ + default_cfg = default_cfgs['skresnet26d'] + sk_kwargs = dict( + keep_3x3=False, + ) + model = ResNet( + SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False + **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Select Kernel ResNet-50 model. + Based on config in "Compounding the Performance Improvements of Assembled Techniques in a + Convolutional Neural Network" + """ + sk_kwargs = dict( + attn_reduction=2, + ) + default_cfg = default_cfgs['skresnet50'] + model = ResNet( + SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Select Kernel ResNet-50-D model. + Based on config in "Compounding the Performance Improvements of Assembled Techniques in a + Convolutional Neural Network" + """ + sk_kwargs = dict( + attn_reduction=2, + ) + default_cfg = default_cfgs['skresnet50d'] + model = ResNet( + SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to + the SKNet50 model in the Select Kernel Paper + """ + default_cfg = default_cfgs['skresnext50_32x4d'] + model = ResNet( + SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, + num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model