diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 0fa4d210..69e09085 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -16,6 +16,7 @@ from .gluon_xception import *
 from .res2net import *
 from .dla import *
 from .hrnet import *
+from .sknet import *
 
 from .registry import *
 from .factory import create_model
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 1d64dcd9..c3104530 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -6,7 +6,6 @@ additional dropout and dynamic global avg/max pool.
 ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
 """
 import math
-from collections import OrderedDict
 
 import torch
 import torch.nn as nn
@@ -101,11 +100,10 @@ default_cfgs = {
     'seresnext26tn_32x4d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
         interpolation='bicubic'),
-    'skresnet26d': _cfg()
 }
 
 
-def _get_padding(kernel_size, stride, dilation=1):
+def get_padding(kernel_size, stride, dilation=1):
     padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
     return padding
 
@@ -234,184 +232,6 @@ class Bottleneck(nn.Module):
         return out
 
 
-class SelectiveKernelAttn(nn.Module):
-    def __init__(self, channels, num_paths=2, num_attn_feat=32,
-                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
-        super(SelectiveKernelAttn, self).__init__()
-        self.num_paths = num_paths
-        self.pool = nn.AdaptiveAvgPool2d(1)
-        self.fc_reduce = nn.Conv2d(channels, num_attn_feat, kernel_size=1, bias=False)
-        self.bn = norm_layer(num_attn_feat)
-        self.act = act_layer(inplace=True)
-        self.fc_select = nn.Conv2d(num_attn_feat, channels * num_paths, kernel_size=1, bias=False)
-
-    def forward(self, x):
-        assert x.shape[1] == self.num_paths
-        x = torch.sum(x, dim=1)
-        #print('attn sum', x.shape)
-        x = self.pool(x)
-        #print('attn pool', x.shape)
-        x = self.fc_reduce(x)
-        x = self.bn(x)
-        x = self.act(x)
-        x = self.fc_select(x)
-        #print('attn sel', x.shape)
-        x = x.view((x.shape[0], self.num_paths, x.shape[1]//self.num_paths) + x.shape[-2:])
-        #print('attn spl', x.shape)
-        x = torch.softmax(x, dim=1)
-        return x
-
-
-class SelectiveKernelConv(nn.Module):
-
-    def __init__(self, in_channels, out_channels, kernel_size=[3, 5], attn_reduction=16,
-                 min_attn_feat=16, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True,
-                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
-        super(SelectiveKernelConv, self).__init__()
-        if not isinstance(kernel_size, list):
-            assert kernel_size >= 3 and kernel_size % 2
-            kernel_size = [kernel_size] * 2
-        else:
-            # FIXME assert kernel sizes >=3 and odd
-            pass
-        if keep_3x3:
-            dilation = [dilation * (k - 1) // 2 for k in kernel_size]
-            kernel_size = [3] * len(kernel_size)
-        else:
-            dilation = [dilation] * len(kernel_size)
-        groups = min(out_channels // len(kernel_size), groups)
-
-        self.conv_paths = nn.ModuleList()
-        for k, d in zip(kernel_size, dilation):
-            p = _get_padding(k, stride, d)
-            self.conv_paths.append(nn.Sequential(OrderedDict([
-                ('conv', nn.Conv2d(
-                    in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)),
-                ('bn', norm_layer(out_channels)),
-                ('act', act_layer(inplace=True))
-            ])))
-
-        if use_attn:
-            num_attn_feat = max(int(out_channels / attn_reduction), min_attn_feat)
-            self.attn = SelectiveKernelAttn(out_channels, len(kernel_size), num_attn_feat)
-        else:
-            self.attn = None
-
-    def forward(self, x):
-        x_paths = []
-        for conv in self.conv_paths:
-            xk = conv(x)
-            x_paths.append(xk)
-        if self.attn is not None:
-            x_paths = torch.stack(x_paths, dim=1)
-            # print('paths', x_paths.shape)
-            x_attn = self.attn(x_paths)
-            #print('attn', x_attn.shape)
-            x = x_paths * x_attn
-            #print('amul', x.shape)
-            x = torch.sum(x, dim=1)
-            #print('asum', x.shape)
-        else:
-            x = torch.cat(x_paths, dim=1)
-        return x
-
-
-class SelectiveKernelBasicBlock(nn.Module):
-    expansion = 1
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None,
-                 cardinality=1, base_width=64, use_se=False,
-                 reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
-        super(SelectiveKernelBasicBlock, self).__init__()
-
-        assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
-        assert base_width == 64, 'BasicBlock doest not support changing base width'
-        first_planes = planes // reduce_first
-        outplanes = planes * self.expansion
-
-        self.conv1 = nn.Conv2d(
-            inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
-            dilation=dilation, bias=False)
-        self.bn1 = norm_layer(first_planes)
-        self.act1 = act_layer(inplace=True)
-        self.conv2 = SelectiveKernelConv(first_planes, outplanes, dilation=previous_dilation)
-        self.bn2 = norm_layer(outplanes)
-        self.se = SEModule(outplanes, planes // 4) if use_se else None
-        self.act2 = act_layer(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-        self.dilation = dilation
-
-    def forward(self, x):
-        residual = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.act1(out)
-        out = self.conv2(out)
-        out = self.bn2(out)
-
-        if self.se is not None:
-            out = self.se(out)
-
-        if self.downsample is not None:
-            residual = self.downsample(x)
-
-        out += residual
-        out = self.act2(out)
-
-        return out
-
-
-class SelectiveKernelBottleneck(nn.Module):
-    expansion = 4
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None,
-                 cardinality=1, base_width=64, use_se=False,
-                 reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
-        super(SelectiveKernelBottleneck, self).__init__()
-
-        width = int(math.floor(planes * (base_width / 64)) * cardinality)
-        first_planes = width // reduce_first
-        outplanes = planes * self.expansion
-
-        self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
-        self.bn1 = norm_layer(first_planes)
-        self.act1 = act_layer(inplace=True)
-        self.conv2 = SelectiveKernelConv(
-            first_planes, width, stride=stride, dilation=dilation, groups=cardinality)
-        self.bn2 = norm_layer(width)
-        self.act2 = act_layer(inplace=True)
-        self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
-        self.bn3 = norm_layer(outplanes)
-        self.act3 = act_layer(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-        self.dilation = dilation
-
-    def forward(self, x):
-        residual = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.act1(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-        out = self.act2(out)
-
-        out = self.conv3(out)
-        out = self.bn3(out)
-
-        if self.downsample is not None:
-            residual = self.downsample(x)
-
-        out += residual
-        out = self.act3(out)
-
-        return out
-
-
 class ResNet(nn.Module):
     """ResNet / ResNeXt / SE-ResNeXt / SE-Net
 
@@ -560,7 +380,7 @@ class ResNet(nn.Module):
         downsample = None
         down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
         if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample_padding = _get_padding(down_kernel_size, stride)
+            downsample_padding = get_padding(down_kernel_size, stride)
             downsample_layers = []
             conv_stride = stride
             if avg_down:
@@ -628,18 +448,6 @@ def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return model
 
 
-@register_model
-def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    """Constructs a ResNet-18 model.
-    """
-    default_cfg = default_cfgs['resnet18']
-    model = ResNet(SelectiveKernelBasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
-
-
 @register_model
 def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Constructs a ResNet-34 model.
@@ -664,19 +472,6 @@ def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return model
 
 
-@register_model
-def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    """Constructs a ResNet-26 model.
-    """
-    default_cfg = default_cfgs['skresnet26d']
-    model = ResNet(
-        SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
-
 @register_model
 def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Constructs a ResNet-26 v1d model.
diff --git a/timm/models/sknet.py b/timm/models/sknet.py
new file mode 100644
index 00000000..4bc2061d
--- /dev/null
+++ b/timm/models/sknet.py
@@ -0,0 +1,294 @@
+import math
+from collections import OrderedDict
+
+import torch
+from torch import nn as nn
+
+from timm.models.registry import register_model
+from timm.models.helpers import load_pretrained
+from timm.models.resnet import ResNet, get_padding, SEModule
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'conv1', 'classifier': 'fc',
+        **kwargs
+    }
+
+
+default_cfgs = {
+    'skresnet18': _cfg(url=''),
+    'skresnet26d': _cfg()
+}
+
+
+class SelectiveKernelAttn(nn.Module):
+    def __init__(self, channels, num_paths=2, attn_channels=32,
+                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+        super(SelectiveKernelAttn, self).__init__()
+        self.num_paths = num_paths
+        self.pool = nn.AdaptiveAvgPool2d(1)
+        self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
+        self.bn = norm_layer(attn_channels)
+        self.act = act_layer(inplace=True)
+        self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
+
+    def forward(self, x):
+        assert x.shape[1] == self.num_paths
+        x = torch.sum(x, dim=1)
+        #print('attn sum', x.shape)
+        x = self.pool(x)
+        #print('attn pool', x.shape)
+        x = self.fc_reduce(x)
+        x = self.bn(x)
+        x = self.act(x)
+        x = self.fc_select(x)
+        #print('attn sel', x.shape)
+        B, C, H, W = x.shape
+        x = x.view(B, self.num_paths, C // self.num_paths, H, W)
+        #print('attn spl', x.shape)
+        x = torch.softmax(x, dim=1)
+        return x
+
+
+def _kernel_valid(k):
+    if isinstance(k, (list, tuple)):
+        for ki in k:
+            _kernel_valid(ki)
+        return
+    assert k >= 3 and k % 2
+
+
+class SelectiveKernelConv(nn.Module):
+
+    def __init__(self, in_channels, out_channels, kernel_size=[3, 5], stride=1, dilation=1, groups=1,
+                 attn_reduction=16, min_attn_channels=32, keep_3x3=True, use_attn=True,
+                 split_input=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+        super(SelectiveKernelConv, self).__init__()
+        _kernel_valid(kernel_size)
+        if not isinstance(kernel_size, list):
+            kernel_size = [kernel_size] * 2
+        if keep_3x3:
+            dilation = [dilation * (k - 1) // 2 for k in kernel_size]
+            kernel_size = [3] * len(kernel_size)
+        else:
+            dilation = [dilation] * len(kernel_size)
+        num_paths = len(kernel_size)
+        self.num_paths = num_paths
+        self.split_input = split_input
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        if split_input:
+            assert in_channels % num_paths == 0 and out_channels % num_paths == 0
+            in_channels = in_channels // num_paths
+            out_channels = out_channels // num_paths
+        groups = min(out_channels, groups)
+
+        self.paths = nn.ModuleList()
+        for k, d in zip(kernel_size, dilation):
+            p = get_padding(k, stride, d)
+            self.paths.append(nn.Sequential(OrderedDict([
+                ('conv', nn.Conv2d(
+                    in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)),
+                ('bn', norm_layer(out_channels)),
+                ('act', act_layer(inplace=True))
+            ])))
+
+        if use_attn:
+            attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
+            self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels)
+        else:
+            self.attn = None
+
+    def forward(self, x):
+        if self.split_input:
+            x_split = torch.split(x, self.in_channels // self.num_paths, 1)
+            x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
+        else:
+            x_paths = [op(x) for op in self.paths]
+
+        x = torch.stack(x_paths, dim=1)
+        # print('paths', x.shape)
+        if self.attn is not None:
+            x_attn = self.attn(x)
+            #print('attn', x_attn.shape)
+            x = x * x_attn
+            #print('amul', x.shape)
+
+        if self.split_input:
+            B, N, C, H, W = x.shape
+            x = x.reshape(B, N * C, H, W)
+        else:
+            x = torch.sum(x, dim=1)
+        #print('aout', x.shape)
+        return x
+
+
+class SelectiveKernelBasic(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None,
+                 cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
+                 reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+        super(SelectiveKernelBasic, self).__init__()
+
+        sk_kwargs = sk_kwargs or {}
+        assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
+        assert base_width == 64, 'BasicBlock does not support changing base width'
+        first_planes = planes // reduce_first
+        outplanes = planes * self.expansion
+
+        _selective_first = True  # FIXME temporary, for experiments
+        if _selective_first:
+            self.conv1 = SelectiveKernelConv(
+                inplanes, first_planes, stride=stride, dilation=dilation, **sk_kwargs)
+        else:
+            self.conv1 = nn.Conv2d(
+                inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation,
+                dilation=dilation, bias=False)
+        self.bn1 = norm_layer(first_planes)
+        self.act1 = act_layer(inplace=True)
+        if _selective_first:
+            self.conv2 = nn.Conv2d(
+                first_planes, outplanes, kernel_size=3, padding=previous_dilation,
+                dilation=previous_dilation, bias=False)
+        else:
+            self.conv2 = SelectiveKernelConv(
+                first_planes, outplanes, dilation=previous_dilation, **sk_kwargs)
+        self.bn2 = norm_layer(outplanes)
+        self.se = SEModule(outplanes, planes // 4) if use_se else None
+        self.act2 = act_layer(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+        self.dilation = dilation
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.act1(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.se is not None:
+            out = self.se(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.act2(out)
+
+        return out
+
+
+class SelectiveKernelBottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None,
+                 cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
+                 reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+        super(SelectiveKernelBottleneck, self).__init__()
+
+        sk_kwargs = sk_kwargs or {}
+        width = int(math.floor(planes * (base_width / 64)) * cardinality)
+        first_planes = width // reduce_first
+        outplanes = planes * self.expansion
+
+        self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
+        self.bn1 = norm_layer(first_planes)
+        self.act1 = act_layer(inplace=True)
+        self.conv2 = SelectiveKernelConv(
+            first_planes, width, stride=stride, dilation=dilation, groups=cardinality, **sk_kwargs)
+        self.bn2 = norm_layer(width)
+        self.act2 = act_layer(inplace=True)
+        self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
+        self.bn3 = norm_layer(outplanes)
+        self.se = SEModule(outplanes, planes // 4) if use_se else None
+        self.act3 = act_layer(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+        self.dilation = dilation
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.act1(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.act2(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.se is not None:
+            out = self.se(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.act3(out)
+
+        return out
+
+
+@register_model
+def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a Selective Kernel ResNet-26-D model.
+    """
+    default_cfg = default_cfgs['skresnet26d']
+    sk_kwargs = dict(
+        keep_3x3=False,
+    )
+    model = ResNet(
+        SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
+        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
+        **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a Selective Kernel ResNet-18 model.
+    """
+    default_cfg = default_cfgs['skresnet18']
+    sk_kwargs = dict(
+        min_attn_channels=16,
+    )
+    model = ResNet(
+        SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
+        block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a Selective Kernel ResNet-18 model with split input paths.
+    """
+    default_cfg = default_cfgs['skresnet18']
+    sk_kwargs = dict(
+        min_attn_channels=16,
+        split_input=True
+    )
+    model = ResNet(
+        SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
+        block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
\ No newline at end of file
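
A quick shape sanity-check for the new SelectiveKernelConv module (a minimal sketch, not part of the patch; it assumes a checkout of this branch is installed so that timm.models.sknet is importable, and the tensor sizes are illustrative):

import torch
from timm.models.sknet import SelectiveKernelConv

# Default config: kernel_size=[3, 5] with keep_3x3=True, so both paths use 3x3 convs
# and the 5x5 path is emulated via dilation=2; use_attn=True enables SelectiveKernelAttn.
sk = SelectiveKernelConv(in_channels=64, out_channels=64, stride=1)
x = torch.randn(2, 64, 56, 56)
y = sk(x)
# The per-path outputs are attention-weighted and summed over the path dimension,
# so the channel count and spatial size are preserved at stride 1.
print(y.shape)  # torch.Size([2, 64, 56, 56])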
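
And a usage sketch for the registered entry points (hypothetical invocation, not from the patch; it assumes the ResNet base class on this branch accepts the block_args keyword that the constructors above pass, and note the default_cfg URLs are empty, so pretrained weights are not available yet):

import torch
from timm.models import create_model  # re-exported via timm/models/__init__.py

# 'skresnet18', 'sksresnet18', and 'skresnet26d' are registered above via @register_model.
model = create_model('skresnet18', pretrained=False, num_classes=1000, in_chans=3)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])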