From c8b3d6b81a478ec72b8d5f75015b3859af926df1 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jan 2020 19:45:05 -0800 Subject: [PATCH 01/13] Initial impl of Selective Kernel Networks. Very much a WIP. --- timm/models/resnet.py | 146 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 422eb0cb..b3adf6dc 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -6,6 +6,7 @@ additional dropout and dynamic global avg/max pool. ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman """ import math +from collections import OrderedDict import torch import torch.nn as nn @@ -100,6 +101,7 @@ default_cfgs = { 'seresnext26tn_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth', interpolation='bicubic'), + 'skresnet26d': _cfg() } @@ -232,6 +234,137 @@ class Bottleneck(nn.Module): return out +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, num_attn_feat=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc_reduce = nn.Conv2d(channels, num_attn_feat, kernel_size=1, bias=False) + self.bn = norm_layer(num_attn_feat) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(num_attn_feat, channels * num_paths, kernel_size=1) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = torch.sum(x, dim=1) + #print('attn sum', x.shape) + x = self.pool(x) + #print('attn pool', x.shape) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + #print('attn sel', x.shape) + x = x.view((x.shape[0], self.num_paths, x.shape[1]//self.num_paths) + x.shape[-2:]) + #print('attn spl', x.shape) + x = torch.softmax(x, dim=1) + return x + + +class SelectiveKernelConv(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=[3, 5], attn_reduction=16, + min_attn_feat=32, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelConv, self).__init__() + if not isinstance(kernel_size, list): + assert kernel_size >= 3 and kernel_size % 2 + kernel_size = [kernel_size] * 2 + else: + # FIXME assert kernel sizes >=3 and odd + pass + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + groups = min(out_channels // len(kernel_size), groups) + + self.conv_paths = nn.ModuleList() + for k, d in zip(kernel_size, dilation): + p = _get_padding(k, stride, d) + self.conv_paths.append(nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d( + in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)), + ('bn', norm_layer(out_channels)), + ('act', act_layer(inplace=True)) + ]))) + + if use_attn: + num_attn_feat = max(int(out_channels / attn_reduction), min_attn_feat) + self.attn = SelectiveKernelAttn(out_channels, len(kernel_size), num_attn_feat) + else: + self.attn = None + + def forward(self, x): + x_paths = [] + for conv in self.conv_paths: + xk = conv(x) + x_paths.append(xk) + if self.attn is not None: + x_paths = torch.stack(x_paths, dim=1) + # print('paths', x_paths.shape) + x_attn = self.attn(x_paths) + #print('attn', x_attn.shape) + x = x_paths * x_attn + #print('amul', 
x.shape) + x = torch.sum(x, dim=1) + #print('asum', x.shape) + else: + x = torch.cat(x_paths, dim=1) + return x + + +class SelectiveKernelBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelBottleneck, self).__init__() + + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + + self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + self.conv2 = SelectiveKernelConv( + first_planes, width, stride=stride, dilation=dilation, groups=cardinality) + self.bn2 = norm_layer(width) + self.act2 = act_layer(inplace=True) + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + self.act3 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.act1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act2(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act3(out) + + return out + + class ResNet(nn.Module): """ResNet / ResNeXt / SE-ResNeXt / SE-Net @@ -472,6 +605,19 @@ def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-26 model. + """ + default_cfg = default_cfgs['skresnet26d'] + model = ResNet( + SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + @register_model def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-26 v1d model. 
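
To illustrate how the pieces added in this first patch fit together (two parallel conv paths whose stacked outputs are weighted by the softmax attention from SelectiveKernelAttn, then summed), a minimal usage sketch follows. The constructor defaults are taken from the diff: kernel_size=[3, 5] with keep_3x3=True collapses to two 3x3 paths with dilations 1 and 2, and the attention bottleneck is max(out_channels // attn_reduction, min_attn_feat), i.e. 32 channels here. The import path and tensor sizes are illustrative assumptions, not part of the patch.

import torch
# At this point in the series, SelectiveKernelConv is defined in resnet.py.
from timm.models.resnet import SelectiveKernelConv

x = torch.randn(2, 64, 56, 56)               # made-up input: N=2, C=64, 56x56
sk = SelectiveKernelConv(64, 128, stride=2)  # defaults: two 3x3 paths, dilations 1 and 2, attention on
out = sk(x)
print(out.shape)                             # torch.Size([2, 128, 28, 28])
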
From ad087b4b1785d5a6f0d2404b8a30ee9ee7add8f0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jan 2020 19:54:37 -0800 Subject: [PATCH 02/13] Missed bias=False in selection conv --- timm/models/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index b3adf6dc..57a20894 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -243,7 +243,7 @@ class SelectiveKernelAttn(nn.Module): self.fc_reduce = nn.Conv2d(channels, num_attn_feat, kernel_size=1, bias=False) self.bn = norm_layer(num_attn_feat) self.act = act_layer(inplace=True) - self.fc_select = nn.Conv2d(num_attn_feat, channels * num_paths, kernel_size=1) + self.fc_select = nn.Conv2d(num_attn_feat, channels * num_paths, kernel_size=1, bias=False) def forward(self, x): assert x.shape[1] == self.num_paths From a93bae6dc5a8831f1633208ac45d598a225810c5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 25 Jan 2020 18:31:08 -0800 Subject: [PATCH 03/13] A SelectiveKernelBasicBlock for more experiments --- timm/models/resnet.py | 61 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 57a20894..1d64dcd9 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -265,7 +265,7 @@ class SelectiveKernelAttn(nn.Module): class SelectiveKernelConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=[3, 5], attn_reduction=16, - min_attn_feat=32, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True, + min_attn_feat=16, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelConv, self).__init__() if not isinstance(kernel_size, list): @@ -316,6 +316,53 @@ class SelectiveKernelConv(nn.Module): return x +class SelectiveKernelBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelBasicBlock, self).__init__() + + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock doest not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, + dilation=dilation, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + self.conv2 = SelectiveKernelConv(first_planes, outplanes, dilation=previous_dilation) + self.bn2 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None + self.act2 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.act1(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.se is not None: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act2(out) + + return out + + class SelectiveKernelBottleneck(nn.Module): expansion = 4 @@ -581,6 +628,18 @@ def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a 
ResNet-18 model. + """ + default_cfg = default_cfgs['resnet18'] + model = ResNet(SelectiveKernelBasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-34 model. From 58e28dc7e7a988fe3566aa15150cfc76261eccc7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 26 Jan 2020 11:23:39 -0800 Subject: [PATCH 04/13] Move Selective Kernel blocks/convs to their own sknet.py file --- timm/models/__init__.py | 1 + timm/models/resnet.py | 209 +--------------------------- timm/models/sknet.py | 294 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 297 insertions(+), 207 deletions(-) create mode 100644 timm/models/sknet.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 0fa4d210..69e09085 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -16,6 +16,7 @@ from .gluon_xception import * from .res2net import * from .dla import * from .hrnet import * +from .sknet import * from .registry import * from .factory import create_model diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 1d64dcd9..c3104530 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -6,7 +6,6 @@ additional dropout and dynamic global avg/max pool. ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman """ import math -from collections import OrderedDict import torch import torch.nn as nn @@ -101,11 +100,10 @@ default_cfgs = { 'seresnext26tn_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth', interpolation='bicubic'), - 'skresnet26d': _cfg() } -def _get_padding(kernel_size, stride, dilation=1): +def get_padding(kernel_size, stride, dilation=1): padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 return padding @@ -234,184 +232,6 @@ class Bottleneck(nn.Module): return out -class SelectiveKernelAttn(nn.Module): - def __init__(self, channels, num_paths=2, num_attn_feat=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelAttn, self).__init__() - self.num_paths = num_paths - self.pool = nn.AdaptiveAvgPool2d(1) - self.fc_reduce = nn.Conv2d(channels, num_attn_feat, kernel_size=1, bias=False) - self.bn = norm_layer(num_attn_feat) - self.act = act_layer(inplace=True) - self.fc_select = nn.Conv2d(num_attn_feat, channels * num_paths, kernel_size=1, bias=False) - - def forward(self, x): - assert x.shape[1] == self.num_paths - x = torch.sum(x, dim=1) - #print('attn sum', x.shape) - x = self.pool(x) - #print('attn pool', x.shape) - x = self.fc_reduce(x) - x = self.bn(x) - x = self.act(x) - x = self.fc_select(x) - #print('attn sel', x.shape) - x = x.view((x.shape[0], self.num_paths, x.shape[1]//self.num_paths) + x.shape[-2:]) - #print('attn spl', x.shape) - x = torch.softmax(x, dim=1) - return x - - -class SelectiveKernelConv(nn.Module): - - def __init__(self, in_channels, out_channels, kernel_size=[3, 5], attn_reduction=16, - min_attn_feat=16, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelConv, self).__init__() - if not isinstance(kernel_size, list): - assert kernel_size >= 3 and kernel_size % 2 - kernel_size = [kernel_size] * 2 - else: - # FIXME assert kernel sizes >=3 
and odd - pass - if keep_3x3: - dilation = [dilation * (k - 1) // 2 for k in kernel_size] - kernel_size = [3] * len(kernel_size) - else: - dilation = [dilation] * len(kernel_size) - groups = min(out_channels // len(kernel_size), groups) - - self.conv_paths = nn.ModuleList() - for k, d in zip(kernel_size, dilation): - p = _get_padding(k, stride, d) - self.conv_paths.append(nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d( - in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)), - ('bn', norm_layer(out_channels)), - ('act', act_layer(inplace=True)) - ]))) - - if use_attn: - num_attn_feat = max(int(out_channels / attn_reduction), min_attn_feat) - self.attn = SelectiveKernelAttn(out_channels, len(kernel_size), num_attn_feat) - else: - self.attn = None - - def forward(self, x): - x_paths = [] - for conv in self.conv_paths: - xk = conv(x) - x_paths.append(xk) - if self.attn is not None: - x_paths = torch.stack(x_paths, dim=1) - # print('paths', x_paths.shape) - x_attn = self.attn(x_paths) - #print('attn', x_attn.shape) - x = x_paths * x_attn - #print('amul', x.shape) - x = torch.sum(x, dim=1) - #print('asum', x.shape) - else: - x = torch.cat(x_paths, dim=1) - return x - - -class SelectiveKernelBasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelBasicBlock, self).__init__() - - assert cardinality == 1, 'BasicBlock only supports cardinality of 1' - assert base_width == 64, 'BasicBlock doest not support changing base width' - first_planes = planes // reduce_first - outplanes = planes * self.expansion - - self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, bias=False) - self.bn1 = norm_layer(first_planes) - self.act1 = act_layer(inplace=True) - self.conv2 = SelectiveKernelConv(first_planes, outplanes, dilation=previous_dilation) - self.bn2 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None - self.act2 = act_layer(inplace=True) - self.downsample = downsample - self.stride = stride - self.dilation = dilation - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - out = self.conv2(out) - out = self.bn2(out) - - if self.se is not None: - out = self.se(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.act2(out) - - return out - - -class SelectiveKernelBottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelBottleneck, self).__init__() - - width = int(math.floor(planes * (base_width / 64)) * cardinality) - first_planes = width // reduce_first - outplanes = planes * self.expansion - - self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) - self.bn1 = norm_layer(first_planes) - self.act1 = act_layer(inplace=True) - self.conv2 = SelectiveKernelConv( - first_planes, width, stride=stride, dilation=dilation, groups=cardinality) - self.bn2 = norm_layer(width) - self.act2 = act_layer(inplace=True) - self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) - self.bn3 = norm_layer(outplanes) - 
self.act3 = act_layer(inplace=True) - self.downsample = downsample - self.stride = stride - self.dilation = dilation - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.act2(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.act3(out) - - return out - - class ResNet(nn.Module): """ResNet / ResNeXt / SE-ResNeXt / SE-Net @@ -560,7 +380,7 @@ class ResNet(nn.Module): downsample = None down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size if stride != 1 or self.inplanes != planes * block.expansion: - downsample_padding = _get_padding(down_kernel_size, stride) + downsample_padding = get_padding(down_kernel_size, stride) downsample_layers = [] conv_stride = stride if avg_down: @@ -628,18 +448,6 @@ def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model -@register_model -def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNet-18 model. - """ - default_cfg = default_cfgs['resnet18'] - model = ResNet(SelectiveKernelBasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - @register_model def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-34 model. @@ -664,19 +472,6 @@ def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model -@register_model -def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNet-26 model. - """ - default_cfg = default_cfgs['skresnet26d'] - model = ResNet( - SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - @register_model def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-26 v1d model. 
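
For reference, the get_padding helper renamed above (dropping the leading underscore so the new sknet.py module can import it) computes 'same'-style padding for strided and dilated convolutions. The function body below is copied from the diff; the worked values are just illustrative arithmetic.

def get_padding(kernel_size, stride, dilation=1):
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding

print(get_padding(3, 1))     # 1: a 3x3, stride-1 conv keeps spatial size
print(get_padding(5, 1))     # 2: likewise for a 5x5 conv
print(get_padding(3, 2, 2))  # 2: the dilated 3x3 path produced when keep_3x3 is set
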
diff --git a/timm/models/sknet.py b/timm/models/sknet.py new file mode 100644 index 00000000..4bc2061d --- /dev/null +++ b/timm/models/sknet.py @@ -0,0 +1,294 @@ +import math +from collections import OrderedDict + +import torch +from torch import nn as nn + +from timm.models.registry import register_model +from timm.models.helpers import load_pretrained +from timm.models.resnet import ResNet, get_padding, SEModule +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'skresnet18': _cfg(url=''), + 'skresnet26d': _cfg() +} + + +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = torch.sum(x, dim=1) + #print('attn sum', x.shape) + x = self.pool(x) + #print('attn pool', x.shape) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + #print('attn sel', x.shape) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + #print('attn spl', x.shape) + x = torch.softmax(x, dim=1) + return x + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelConv(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=[3, 5], stride=1, dilation=1, groups=1, + attn_reduction=16, min_attn_channels=32, keep_3x3=True, use_attn=True, + split_input=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelConv, self).__init__() + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + num_paths = len(kernel_size) + self.num_paths = num_paths + self.split_input = split_input + self.in_channels = in_channels + self.out_channels = out_channels + if split_input: + assert in_channels % num_paths == 0 and out_channels % num_paths == 0 + in_channels = in_channels // num_paths + out_channels = out_channels // num_paths + groups = min(out_channels, groups) + + self.paths = nn.ModuleList() + for k, d in zip(kernel_size, dilation): + p = get_padding(k, stride, d) + self.paths.append(nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d( + in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)), + ('bn', norm_layer(out_channels)), + ('act', act_layer(inplace=True)) + ]))) + + if use_attn: + attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) + self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels) + else: + self.attn = None + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, 
self.out_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + + if self.attn is not None: + x = torch.stack(x_paths, dim=1) + # print('paths', x_paths.shape) + x_attn = self.attn(x) + #print('attn', x_attn.shape) + x = x * x_attn + #print('amul', x.shape) + + if self.split_input: + B, N, C, H, W = x.shape + x = x.reshape(B, N * C, H, W) + else: + x = torch.sum(x, dim=1) + #print('aout', x.shape) + return x + + +class SelectiveKernelBasic(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, use_se=False, sk_kwargs=None, + reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelBasic, self).__init__() + + sk_kwargs = sk_kwargs or {} + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock doest not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + + _selective_first = True # FIXME temporary, for experiments + if _selective_first: + self.conv1 = SelectiveKernelConv( + inplanes, first_planes, stride=stride, dilation=dilation, **sk_kwargs) + else: + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, + dilation=dilation, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + if _selective_first: + self.conv2 = nn.Conv2d( + first_planes, outplanes, kernel_size=3, padding=previous_dilation, + dilation=previous_dilation, bias=False) + else: + self.conv2 = SelectiveKernelConv( + first_planes, outplanes, dilation=previous_dilation, **sk_kwargs) + self.bn2 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None + self.act2 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.act1(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.se is not None: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act2(out) + + return out + + +class SelectiveKernelBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, use_se=False, sk_kwargs=None, + reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelBottleneck, self).__init__() + + sk_kwargs = sk_kwargs or {} + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + + self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + self.conv2 = SelectiveKernelConv( + first_planes, width, stride=stride, dilation=dilation, groups=cardinality, **sk_kwargs) + self.bn2 = norm_layer(width) + self.act2 = act_layer(inplace=True) + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None + self.act3 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + + def forward(self, x): + residual = x + + out = 
self.conv1(x) + out = self.bn1(out) + out = self.act1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act2(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.se is not None: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act3(out) + + return out + + +@register_model +def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-26 model. + """ + default_cfg = default_cfgs['skresnet26d'] + sk_kwargs = dict( + keep_3x3=False, + ) + model = ResNet( + SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), + **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['skresnet18'] + sk_kwargs = dict( + min_attn_channels=16, + ) + model = ResNet( + SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['skresnet18'] + sk_kwargs = dict( + min_attn_channels=16, + split_input=True + ) + model = ResNet( + SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model \ No newline at end of file From 9abe6109316bca240ca71a32806b1e6e39979ce0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 26 Jan 2020 11:33:31 -0800 Subject: [PATCH 05/13] Used wrong channel var for split --- timm/models/sknet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 4bc2061d..d3a4fb5d 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -106,7 +106,7 @@ class SelectiveKernelConv(nn.Module): def forward(self, x): if self.split_input: - x_split = torch.split(x, self.out_channels // self.num_paths, 1) + x_split = torch.split(x, self.in_channels // self.num_paths, 1) x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] else: x_paths = [op(x) for op in self.paths] From cefc9b7761584f7118e65a6bbdcc3887f3b0de62 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 27 Jan 2020 21:48:28 -0800 Subject: [PATCH 06/13] Move SelectKernelConv to conv2d_layers and more * always apply attention in SelectKernelConv, leave MixedConv for no attention alternative * make MixedConv torchscript compatible * refactor first/previous dilation name to make more sense in ResNet* networks --- timm/models/conv2d_layers.py | 102 ++++++++++++++++++++++++++-- timm/models/res2net.py | 9 +-- timm/models/resnet.py | 33 +++++---- timm/models/sknet.py | 126 ++++------------------------------- 4 files changed, 128 insertions(+), 142 deletions(-) diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py index acd14fde..7583263a 100644 --- a/timm/models/conv2d_layers.py +++ b/timm/models/conv2d_layers.py @@ 
-1,3 +1,5 @@ +from collections import OrderedDict + import torch import torch.nn as nn import torch.nn.functional as F @@ -100,14 +102,11 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) -class MixedConv2d(nn.Module): +class MixedConv2d(nn.ModuleDict): """ Mixed Grouped Convolution Based on MDConv and GroupedConv in MixNet impl: https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py - - NOTE: This does not currently work with torch.jit.script """ - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='', dilation=1, depthwise=False, **kwargs): super(MixedConv2d, self).__init__() @@ -131,7 +130,7 @@ class MixedConv2d(nn.Module): def forward(self, x): x_split = torch.split(x, self.splits, 1) - x_out = [c(x) for x, c in zip(x_split, self._modules.values())] + x_out = [c(x_split[i]) for i, c in enumerate(self.values())] x = torch.cat(x_out, 1) return x @@ -240,6 +239,97 @@ class CondConv2d(nn.Module): return out +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = torch.sum(x, dim=1) + x = self.pool(x) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + x = torch.softmax(x, dim=1) + return x + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelConv(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, + attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelConv, self).__init__() + kernel_size = kernel_size or [3, 5] + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + num_paths = len(kernel_size) + self.num_paths = num_paths + self.split_input = split_input + self.in_channels = in_channels + self.out_channels = out_channels + if split_input: + assert in_channels % num_paths == 0 and out_channels % num_paths == 0 + in_channels = in_channels // num_paths + out_channels = out_channels // num_paths + groups = min(out_channels, groups) + + self.paths = nn.ModuleList() + for k, d in zip(kernel_size, dilation): + p = _get_padding(k, stride, d) + self.paths.append(nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d( + in_channels, out_channels, kernel_size=k, stride=stride, padding=p, + dilation=d, groups=groups, bias=False)), + ('bn', norm_layer(out_channels)), + ('act', act_layer(inplace=True)) + ]))) + + attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) + self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels) + + 
def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.in_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + + x = torch.stack(x_paths, dim=1) + x_attn = self.attn(x) + x = x * x_attn + + if self.split_input: + B, N, C, H, W = x.shape + x = x.reshape(B, N * C, H, W) + else: + x = torch.sum(x, dim=1) + return x + + # helper method def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): assert 'groups' not in kwargs # only use 'depthwise' bool arg @@ -256,5 +346,3 @@ def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): else: m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) return m - - diff --git a/timm/models/res2net.py b/timm/models/res2net.py index da20e7a0..c83aba62 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -54,14 +54,15 @@ class Bottle2neck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=26, scale=4, use_se=False, - act_layer=nn.ReLU, norm_layer=None, dilation=1, previous_dilation=1, **_): + act_layer=nn.ReLU, norm_layer=None, dilation=1, first_dilation=None, **_): super(Bottle2neck, self).__init__() self.scale = scale self.is_first = stride > 1 or downsample is not None self.num_scales = max(1, scale - 1) width = int(math.floor(planes * (base_width / 64.0))) * cardinality - outplanes = planes * self.expansion self.width = width + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) self.bn1 = norm_layer(width * scale) @@ -70,8 +71,8 @@ class Bottle2neck(nn.Module): bns = [] for i in range(self.num_scales): convs.append(nn.Conv2d( - width, width, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, groups=cardinality, bias=False)) + width, width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False)) bns.append(norm_layer(width)) self.convs = nn.ModuleList(convs) self.bns = nn.ModuleList(bns) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index c3104530..976d4234 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -131,24 +131,23 @@ class BasicBlock(nn.Module): __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, bias=False) + inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( - first_planes, outplanes, kernel_size=3, padding=previous_dilation, - 
dilation=previous_dilation, bias=False) + first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) self.bn2 = norm_layer(outplanes) self.se = SEModule(outplanes, planes // 4) if use_se else None self.act2 = act_layer(inplace=True) @@ -181,21 +180,21 @@ class Bottleneck(nn.Module): __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( first_planes, width, kernel_size=3, stride=stride, - padding=dilation, dilation=dilation, groups=cardinality, bias=False) + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) self.bn2 = norm_layer(width) self.act2 = act_layer(inplace=True) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) @@ -396,13 +395,11 @@ class ResNet(nn.Module): first_dilation = 1 if dilation in (1, 2) else 2 bkwargs = dict( cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, - use_se=use_se, **kwargs) - layers = [block( - self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)] + dilation=dilation, use_se=use_se, **kwargs) + layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **bkwargs)] self.inplanes = planes * block.expansion for i in range(1, blocks): - layers.append(block( - self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bkwargs)) + layers.append(block(self.inplanes, planes, **bkwargs)) return nn.Sequential(*layers) @@ -430,8 +427,8 @@ class ResNet(nn.Module): def forward(self, x): x = self.forward_features(x) x = self.global_pool(x).flatten(1) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) x = self.fc(x) return x diff --git a/timm/models/sknet.py b/timm/models/sknet.py index d3a4fb5d..e9dbf68d 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -1,12 +1,11 @@ import math -from collections import OrderedDict -import torch from torch import nn as nn from timm.models.registry import register_model from timm.models.helpers import load_pretrained -from timm.models.resnet import ResNet, get_padding, SEModule +from timm.models.conv2d_layers import SelectiveKernelConv +from timm.models.resnet import ResNet, SEModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -27,113 +26,12 @@ default_cfgs = { } -class SelectiveKernelAttn(nn.Module): - def __init__(self, channels, num_paths=2, attn_channels=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelAttn, self).__init__() - self.num_paths = num_paths - self.pool = nn.AdaptiveAvgPool2d(1) - self.fc_reduce = 
nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) - self.bn = norm_layer(attn_channels) - self.act = act_layer(inplace=True) - self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) - - def forward(self, x): - assert x.shape[1] == self.num_paths - x = torch.sum(x, dim=1) - #print('attn sum', x.shape) - x = self.pool(x) - #print('attn pool', x.shape) - x = self.fc_reduce(x) - x = self.bn(x) - x = self.act(x) - x = self.fc_select(x) - #print('attn sel', x.shape) - B, C, H, W = x.shape - x = x.view(B, self.num_paths, C // self.num_paths, H, W) - #print('attn spl', x.shape) - x = torch.softmax(x, dim=1) - return x - - -def _kernel_valid(k): - if isinstance(k, (list, tuple)): - for ki in k: - return _kernel_valid(ki) - assert k >= 3 and k % 2 - - -class SelectiveKernelConv(nn.Module): - - def __init__(self, in_channels, out_channels, kernel_size=[3, 5], stride=1, dilation=1, groups=1, - attn_reduction=16, min_attn_channels=32, keep_3x3=True, use_attn=True, - split_input=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelConv, self).__init__() - _kernel_valid(kernel_size) - if not isinstance(kernel_size, list): - kernel_size = [kernel_size] * 2 - if keep_3x3: - dilation = [dilation * (k - 1) // 2 for k in kernel_size] - kernel_size = [3] * len(kernel_size) - else: - dilation = [dilation] * len(kernel_size) - num_paths = len(kernel_size) - self.num_paths = num_paths - self.split_input = split_input - self.in_channels = in_channels - self.out_channels = out_channels - if split_input: - assert in_channels % num_paths == 0 and out_channels % num_paths == 0 - in_channels = in_channels // num_paths - out_channels = out_channels // num_paths - groups = min(out_channels, groups) - - self.paths = nn.ModuleList() - for k, d in zip(kernel_size, dilation): - p = get_padding(k, stride, d) - self.paths.append(nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d( - in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)), - ('bn', norm_layer(out_channels)), - ('act', act_layer(inplace=True)) - ]))) - - if use_attn: - attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) - self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels) - else: - self.attn = None - - def forward(self, x): - if self.split_input: - x_split = torch.split(x, self.in_channels // self.num_paths, 1) - x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] - else: - x_paths = [op(x) for op in self.paths] - - if self.attn is not None: - x = torch.stack(x_paths, dim=1) - # print('paths', x_paths.shape) - x_attn = self.attn(x) - #print('attn', x_attn.shape) - x = x * x_attn - #print('amul', x.shape) - - if self.split_input: - B, N, C, H, W = x.shape - x = x.reshape(B, N * C, H, W) - else: - x = torch.sum(x, dim=1) - #print('aout', x.shape) - return x - - class SelectiveKernelBasic(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, sk_kwargs=None, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelBasic, self).__init__() sk_kwargs = sk_kwargs or {} @@ -141,24 +39,25 @@ class SelectiveKernelBasic(nn.Module): assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first outplanes = planes * self.expansion 
+ first_dilation = first_dilation or dilation _selective_first = True # FIXME temporary, for experiments if _selective_first: self.conv1 = SelectiveKernelConv( - inplanes, first_planes, stride=stride, dilation=dilation, **sk_kwargs) + inplanes, first_planes, stride=stride, dilation=first_dilation, **sk_kwargs) else: self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, bias=False) + inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) if _selective_first: self.conv2 = nn.Conv2d( - first_planes, outplanes, kernel_size=3, padding=previous_dilation, - dilation=previous_dilation, bias=False) + first_planes, outplanes, kernel_size=3, padding=dilation, + dilation=dilation, bias=False) else: self.conv2 = SelectiveKernelConv( - first_planes, outplanes, dilation=previous_dilation, **sk_kwargs) + first_planes, outplanes, dilation=dilation, **sk_kwargs) self.bn2 = norm_layer(outplanes) self.se = SEModule(outplanes, planes // 4) if use_se else None self.act2 = act_layer(inplace=True) @@ -192,19 +91,20 @@ class SelectiveKernelBottleneck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, sk_kwargs=None, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelBottleneck, self).__init__() sk_kwargs = sk_kwargs or {} width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = SelectiveKernelConv( - first_planes, width, stride=stride, dilation=dilation, groups=cardinality, **sk_kwargs) + first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, **sk_kwargs) self.bn2 = norm_layer(width) self.act2 = act_layer(inplace=True) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) From 9f11b4e8a25495874d84a56d4ca11af191a01324 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 29 Jan 2020 13:01:35 -0800 Subject: [PATCH 07/13] Add ConvBnAct layer to parallel integrated SelectKernelConv, add support for DropPath and DropBlock to ResNet base and SK blocks --- timm/models/conv2d_layers.py | 43 +++++++++---- timm/models/resnet.py | 20 +++--- timm/models/sknet.py | 118 +++++++++++++++-------------------- 3 files changed, 96 insertions(+), 85 deletions(-) diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py index 7583263a..5b1a44e8 100644 --- a/timm/models/conv2d_layers.py +++ b/timm/models/conv2d_layers.py @@ -271,11 +271,36 @@ def _kernel_valid(k): assert k >= 3 and k % 2 +class ConvBnAct(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(ConvBnAct, self).__init__() + padding = _get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block + self.conv = nn.Conv2d( + in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, 
bias=False) + self.bn = norm_layer(out_channels) + self.drop_block = drop_block + if act_layer is not None: + self.act = act_layer(inplace=True) + else: + self.act = None + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.drop_block is not None: + x = self.drop_block(x) + if self.act is not None: + x = self.act(x) + return x + + class SelectiveKernelConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelConv, self).__init__() kernel_size = kernel_size or [3, 5] _kernel_valid(kernel_size) @@ -297,19 +322,15 @@ class SelectiveKernelConv(nn.Module): out_channels = out_channels // num_paths groups = min(out_channels, groups) - self.paths = nn.ModuleList() - for k, d in zip(kernel_size, dilation): - p = _get_padding(k, stride, d) - self.paths.append(nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d( - in_channels, out_channels, kernel_size=k, stride=stride, padding=p, - dilation=d, groups=groups, bias=False)), - ('bn', norm_layer(out_channels)), - ('act', act_layer(inplace=True)) - ]))) + conv_kwargs = dict( + stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) + self.paths = nn.ModuleList([ + ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + for k, d in zip(kernel_size, dilation)]) attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels) + self.drop_block = drop_block def forward(self, x): if self.split_input: diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 976d4234..3e0ce23e 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -14,6 +14,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .nn_ops import DropBlock2d, DropPath from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -132,7 +133,8 @@ class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + drop_block=None, drop_path=None): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' @@ -181,7 +183,8 @@ class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + drop_block=None, drop_path=None): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) @@ -305,8 +308,8 @@ class ResNet(nn.Module): def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, cardinality=1, base_width=64, stem_width=64, stem_type='', block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, 
global_pool='avg', - zero_init_last_bn=True, block_args=None): + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0., + drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): block_args = block_args or dict() self.num_classes = num_classes deep_stem = 'deep' in stem_type @@ -338,6 +341,9 @@ class ResNet(nn.Module): self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Feature Blocks + dp = DropPath(drop_path_rate) if drop_block_rate else None + db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None + db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4 if output_stride == 16: strides[3] = 1 @@ -350,11 +356,11 @@ class ResNet(nn.Module): llargs = list(zip(channels, layers, strides, dilations)) lkwargs = dict( use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, - avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args) + avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args) self.layer1 = self._make_layer(block, *llargs[0], **lkwargs) self.layer2 = self._make_layer(block, *llargs[1], **lkwargs) - self.layer3 = self._make_layer(block, *llargs[2], **lkwargs) - self.layer4 = self._make_layer(block, *llargs[3], **lkwargs) + self.layer3 = self._make_layer(block, drop_block=db_3, *llargs[2], **lkwargs) + self.layer4 = self._make_layer(block, drop_block=db_4, *llargs[3], **lkwargs) # Head (Pooling and Classifier) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) diff --git a/timm/models/sknet.py b/timm/models/sknet.py index e9dbf68d..41e19075 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -4,7 +4,7 @@ from torch import nn as nn from timm.models.registry import register_model from timm.models.helpers import load_pretrained -from timm.models.conv2d_layers import SelectiveKernelConv +from timm.models.conv2d_layers import SelectiveKernelConv, ConvBnAct from timm.models.resnet import ResNet, SEModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -29,61 +29,53 @@ default_cfgs = { class SelectiveKernelBasic(nn.Module): expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, sk_kwargs=None, - reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + use_se=False, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, + drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelBasic, self).__init__() sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first - outplanes = planes * self.expansion + out_planes = planes * self.expansion first_dilation = first_dilation or dilation _selective_first = True # FIXME temporary, for experiments if _selective_first: self.conv1 = SelectiveKernelConv( - inplanes, first_planes, stride=stride, dilation=first_dilation, **sk_kwargs) - else: - self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, - dilation=first_dilation, 
bias=False) - self.bn1 = norm_layer(first_planes) - self.act1 = act_layer(inplace=True) - if _selective_first: - self.conv2 = nn.Conv2d( - first_planes, outplanes, kernel_size=3, padding=dilation, - dilation=dilation, bias=False) + inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv2 = ConvBnAct( + first_planes, out_planes, kernel_size=3, dilation=dilation, **conv_kwargs) else: + self.conv1 = ConvBnAct( + inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs) + conv_kwargs['act_layer'] = None self.conv2 = SelectiveKernelConv( - first_planes, outplanes, dilation=dilation, **sk_kwargs) - self.bn2 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None - self.act2 = act_layer(inplace=True) + first_planes, out_planes, dilation=dilation, **conv_kwargs, **sk_kwargs) + self.se = SEModule(out_planes, planes // 4) if use_se else None + self.act = act_layer(inplace=True) self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path def forward(self, x): residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - out = self.conv2(out) - out = self.bn2(out) - + x = self.conv1(x) + x = self.conv2(x) if self.se is not None: - out = self.se(out) - + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.act2(out) - - return out + residual = self.downsample(residual) + x += residual + x = self.act(x) + return x class SelectiveKernelBottleneck(nn.Module): @@ -91,54 +83,46 @@ class SelectiveKernelBottleneck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, sk_kwargs=None, - reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + reduce_first=1, dilation=1, first_dilation=None, + drop_block=None, drop_path=None, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelBottleneck, self).__init__() sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first - outplanes = planes * self.expansion + out_planes = planes * self.expansion first_dilation = first_dilation or dilation - self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) - self.bn1 = norm_layer(first_planes) - self.act1 = act_layer(inplace=True) + self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) self.conv2 = SelectiveKernelConv( - first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, **sk_kwargs) - self.bn2 = norm_layer(width) - self.act2 = act_layer(inplace=True) - self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) - self.bn3 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None - self.act3 = act_layer(inplace=True) + first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, + **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv3 = ConvBnAct(width, out_planes, kernel_size=1, **conv_kwargs) + self.se = SEModule(out_planes, planes // 4) if use_se else None + self.act = act_layer(inplace=True) self.downsample = downsample self.stride = stride 
self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path def forward(self, x): residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.act2(out) - - out = self.conv3(out) - out = self.bn3(out) - + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) if self.se is not None: - out = self.se(out) - + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.act3(out) - - return out + residual = self.downsample(residual) + x += residual + x = self.act(x) + return x @register_model From 3ff19079f993e9a91206a04c2b61896883064eeb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 29 Jan 2020 13:11:38 -0800 Subject: [PATCH 08/13] Missed nn_ops.py from last commit --- timm/models/nn_ops.py | 146 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 timm/models/nn_ops.py diff --git a/timm/models/nn_ops.py b/timm/models/nn_ops.py new file mode 100644 index 00000000..37286611 --- /dev/null +++ b/timm/models/nn_ops.py @@ -0,0 +1,146 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + +## Assembled CNN Tensorflow Impl +# +# def _bernoulli(shape, mean, seed=None, dtype=tf.float32): +# return tf.nn.relu(tf.sign(mean - tf.random_uniform(shape, minval=0, maxval=1, dtype=dtype, seed=seed))) +# +# +# def dropblock(x, keep_prob, block_size, gamma_scale=1.0, seed=None, name=None, +# data_format='channels_last', is_training=True): # pylint: disable=invalid-name +# """ +# Dropblock layer. For more details, refer to https://arxiv.org/abs/1810.12890 +# :param x: A floating point tensor. +# :param keep_prob: A scalar Tensor with the same type as x. The probability that each element is kept. +# :param block_size: The block size to drop +# :param gamma_scale: The multiplier to gamma. +# :param seed: Python integer. Used to create random seeds. +# :param name: A name for this operation (optional) +# :param data_format: 'channels_last' or 'channels_first' +# :param is_training: If False, do nothing. +# :return: A Tensor of the same shape of x. +# """ +# if not is_training: +# return x +# +# # Early return if nothing needs to be dropped. +# if (isinstance(keep_prob, float) and keep_prob == 1) or gamma_scale == 0: +# return x +# +# with tf.name_scope(name, "dropblock", [x]) as name: +# if not x.dtype.is_floating: +# raise ValueError("x has to be a floating point tensor since it's going to" +# " be scaled. Got a %s tensor instead." % x.dtype) +# if isinstance(keep_prob, float) and not 0 < keep_prob <= 1: +# raise ValueError("keep_prob must be a scalar tensor or a float in the " +# "range (0, 1], got %g" % keep_prob) +# +# br = (block_size - 1) // 2 +# tl = (block_size - 1) - br +# if data_format == 'channels_last': +# _, h, w, c = x.shape.as_list() +# sampling_mask_shape = tf.stack([1, h - block_size + 1, w - block_size + 1, c]) +# pad_shape = [[0, 0], [tl, br], [tl, br], [0, 0]] +# elif data_format == 'channels_first': +# _, c, h, w = x.shape.as_list() +# sampling_mask_shape = tf.stack([1, c, h - block_size + 1, w - block_size + 1]) +# pad_shape = [[0, 0], [0, 0], [tl, br], [tl, br]] +# else: +# raise NotImplementedError +# +# gamma = (1. 
- keep_prob) * (w * h) / (block_size ** 2) / ((w - block_size + 1) * (h - block_size + 1)) +# gamma = gamma_scale * gamma +# mask = _bernoulli(sampling_mask_shape, gamma, seed, tf.float32) +# mask = tf.pad(mask, pad_shape) +# +# xdtype_mask = tf.cast(mask, x.dtype) +# xdtype_mask = tf.layers.max_pooling2d( +# inputs=xdtype_mask, pool_size=block_size, +# strides=1, padding='SAME', +# data_format=data_format) +# +# xdtype_mask = 1 - xdtype_mask +# fp32_mask = tf.cast(xdtype_mask, tf.float32) +# ret = tf.multiply(x, xdtype_mask) +# float32_mask_size = tf.cast(tf.size(fp32_mask), tf.float32) +# float32_mask_reduce_sum = tf.reduce_sum(fp32_mask) +# normalize_factor = tf.cast(float32_mask_size / (float32_mask_reduce_sum + 1e-8), x.dtype) +# ret = ret * normalize_factor +# return ret + + +def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False): + _, _, height, width = x.shape + total_size = width * height + clipped_block_size = min(block_size, min(width, height)) + # seed_drop_rate, the gamma parameter + seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (width - block_size + 1) * + (height - block_size + 1)) + + # Forces the block to be inside the feature map. + w_i, h_i = torch.meshgrid(torch.arange(width), torch.arange(height)) + valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \ + ((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, height, width)) + valid_block = valid_block.to(x.dtype) + + uniform_noise = torch.rand_like(x) + block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).to(x.dtype) + block_mask = -F.max_pool2d( + -block_mask, + kernel_size=clipped_block_size, # block_size, + stride=1, + padding=clipped_block_size // 2) + + if drop_with_noise: + normal_noise = torch.randn_like(x) + x = x * block_mask + normal_noise * (1 - block_mask) + else: + normalize_scale = block_mask.numel() / (torch.sum(block_mask, dtype=torch.float32) + 1e-7) + x = x * block_mask * normalize_scale + return x + + +class DropBlock2d(nn.Module): + """ DropBlock. 
See https://arxiv.org/pdf/1810.12890.pdf + """ + def __init__(self, + drop_prob=0.1, + block_size=7, + gamma_scale=1.0, + with_noise=False): + super(DropBlock2d, self).__init__() + self.drop_prob = drop_prob + self.gamma_scale = gamma_scale + self.block_size = block_size + self.with_noise = with_noise + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + return drop_block_2d(x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise) + + +def drop_path(x, drop_prob=0.): + """Drop paths (Stochastic Depth) per sample (when applied in residual blocks).""" + keep_prob = 1 - drop_prob + random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.ModuleDict): + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + return drop_path(x, self.drop_prob) From ef457555d3fdf53ba9ec7765bf9477c5c86b84e6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 29 Jan 2020 14:34:45 -0800 Subject: [PATCH 09/13] BlockDrop working on GPU --- timm/models/nn_ops.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/timm/models/nn_ops.py b/timm/models/nn_ops.py index 37286611..9b931efb 100644 --- a/timm/models/nn_ops.py +++ b/timm/models/nn_ops.py @@ -83,14 +83,13 @@ def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noi (height - block_size + 1)) # Forces the block to be inside the feature map. - w_i, h_i = torch.meshgrid(torch.arange(width), torch.arange(height)) + w_i, h_i = torch.meshgrid(torch.arange(width).to(x.device), torch.arange(height).to(x.device)) valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \ ((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2)) - valid_block = torch.reshape(valid_block, (1, 1, height, width)) - valid_block = valid_block.to(x.dtype) + valid_block = torch.reshape(valid_block, (1, 1, height, width)).float() - uniform_noise = torch.rand_like(x) - block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).to(x.dtype) + uniform_noise = torch.rand_like(x, dtype=torch.float32) + block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) block_mask = -F.max_pool2d( -block_mask, kernel_size=clipped_block_size, # block_size, From 355aa152d5c89f210ae0771f800598409807dacf Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 29 Jan 2020 14:51:34 -0800 Subject: [PATCH 10/13] Just leave it float for now, will look at fp16 later. Remove unused reference code. --- timm/models/nn_ops.py | 72 ++----------------------------------------- 1 file changed, 2 insertions(+), 70 deletions(-) diff --git a/timm/models/nn_ops.py b/timm/models/nn_ops.py index 9b931efb..30b98427 100644 --- a/timm/models/nn_ops.py +++ b/timm/models/nn_ops.py @@ -4,74 +4,6 @@ import torch.nn.functional as F import numpy as np import math -## Assembled CNN Tensorflow Impl -# -# def _bernoulli(shape, mean, seed=None, dtype=tf.float32): -# return tf.nn.relu(tf.sign(mean - tf.random_uniform(shape, minval=0, maxval=1, dtype=dtype, seed=seed))) -# -# -# def dropblock(x, keep_prob, block_size, gamma_scale=1.0, seed=None, name=None, -# data_format='channels_last', is_training=True): # pylint: disable=invalid-name -# """ -# Dropblock layer. 
For more details, refer to https://arxiv.org/abs/1810.12890 -# :param x: A floating point tensor. -# :param keep_prob: A scalar Tensor with the same type as x. The probability that each element is kept. -# :param block_size: The block size to drop -# :param gamma_scale: The multiplier to gamma. -# :param seed: Python integer. Used to create random seeds. -# :param name: A name for this operation (optional) -# :param data_format: 'channels_last' or 'channels_first' -# :param is_training: If False, do nothing. -# :return: A Tensor of the same shape of x. -# """ -# if not is_training: -# return x -# -# # Early return if nothing needs to be dropped. -# if (isinstance(keep_prob, float) and keep_prob == 1) or gamma_scale == 0: -# return x -# -# with tf.name_scope(name, "dropblock", [x]) as name: -# if not x.dtype.is_floating: -# raise ValueError("x has to be a floating point tensor since it's going to" -# " be scaled. Got a %s tensor instead." % x.dtype) -# if isinstance(keep_prob, float) and not 0 < keep_prob <= 1: -# raise ValueError("keep_prob must be a scalar tensor or a float in the " -# "range (0, 1], got %g" % keep_prob) -# -# br = (block_size - 1) // 2 -# tl = (block_size - 1) - br -# if data_format == 'channels_last': -# _, h, w, c = x.shape.as_list() -# sampling_mask_shape = tf.stack([1, h - block_size + 1, w - block_size + 1, c]) -# pad_shape = [[0, 0], [tl, br], [tl, br], [0, 0]] -# elif data_format == 'channels_first': -# _, c, h, w = x.shape.as_list() -# sampling_mask_shape = tf.stack([1, c, h - block_size + 1, w - block_size + 1]) -# pad_shape = [[0, 0], [0, 0], [tl, br], [tl, br]] -# else: -# raise NotImplementedError -# -# gamma = (1. - keep_prob) * (w * h) / (block_size ** 2) / ((w - block_size + 1) * (h - block_size + 1)) -# gamma = gamma_scale * gamma -# mask = _bernoulli(sampling_mask_shape, gamma, seed, tf.float32) -# mask = tf.pad(mask, pad_shape) -# -# xdtype_mask = tf.cast(mask, x.dtype) -# xdtype_mask = tf.layers.max_pooling2d( -# inputs=xdtype_mask, pool_size=block_size, -# strides=1, padding='SAME', -# data_format=data_format) -# -# xdtype_mask = 1 - xdtype_mask -# fp32_mask = tf.cast(xdtype_mask, tf.float32) -# ret = tf.multiply(x, xdtype_mask) -# float32_mask_size = tf.cast(tf.size(fp32_mask), tf.float32) -# float32_mask_reduce_sum = tf.reduce_sum(fp32_mask) -# normalize_factor = tf.cast(float32_mask_size / (float32_mask_reduce_sum + 1e-8), x.dtype) -# ret = ret * normalize_factor -# return ret - def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False): _, _, height, width = x.shape @@ -89,7 +21,7 @@ def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noi valid_block = torch.reshape(valid_block, (1, 1, height, width)).float() uniform_noise = torch.rand_like(x, dtype=torch.float32) - block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) + block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float() block_mask = -F.max_pool2d( -block_mask, kernel_size=clipped_block_size, # block_size, @@ -100,7 +32,7 @@ def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noi normal_noise = torch.randn_like(x) x = x * block_mask + normal_noise * (1 - block_mask) else: - normalize_scale = block_mask.numel() / (torch.sum(block_mask, dtype=torch.float32) + 1e-7) + normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7) x = x * block_mask * normalize_scale return x From a9d2424fd1680590146bcd4eed912cc84bbe6a5e Mon Sep 17 00:00:00 2001 From: 
Ross Wightman Date: Thu, 30 Jan 2020 16:51:49 -0800 Subject: [PATCH 11/13] Add separate zero_init_last_bn function to support more block variety without a mess --- timm/models/res2net.py | 3 ++ timm/models/resnet.py | 86 +++++++++++++++++++++++++++--------------- timm/models/sknet.py | 6 +++ 3 files changed, 64 insertions(+), 31 deletions(-) diff --git a/timm/models/res2net.py b/timm/models/res2net.py index c83aba62..bcb7eaaf 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -87,6 +87,9 @@ class Bottle2neck(nn.Module): self.relu = act_layer(inplace=True) self.downsample = downsample + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + def forward(self, x): residual = x diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 3e0ce23e..d97a6aad 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -156,26 +156,38 @@ class BasicBlock(nn.Module): self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn2.weight) def forward(self, x): residual = x - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - out = self.conv2(out) - out = self.bn2(out) + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) if self.se is not None: - out = self.se(out) + x = self.se(x) - if self.downsample is not None: - residual = self.downsample(x) + if self.drop_path is not None: + x = self.drop_path(x) - out += residual - out = self.act2(out) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act2(x) - return out + return x class Bottleneck(nn.Module): @@ -207,31 +219,44 @@ class Bottleneck(nn.Module): self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) def forward(self, x): residual = x - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) - out = self.conv2(out) - out = self.bn2(out) - out = self.act2(out) + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act2(x) - out = self.conv3(out) - out = self.bn3(out) + x = self.conv3(x) + x = self.bn3(x) + if self.drop_block is not None: + x = self.drop_block(x) if self.se is not None: - out = self.se(out) + x = self.se(x) - if self.downsample is not None: - residual = self.downsample(x) + if self.drop_path is not None: + x = self.drop_path(x) - out += residual - out = self.act3(out) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act3(x) - return out + return x class ResNet(nn.Module): @@ -367,17 +392,16 @@ class ResNet(nn.Module): self.num_features = 512 * block.expansion self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) - last_bn_name = 'bn3' if 'Bottle' in block.__name__ else 'bn2' for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): - if zero_init_last_bn and 'layer' in n and last_bn_name in n: - # Initialize weight/gamma of last BN in 
each residual block to zero - nn.init.constant_(m.weight, 0.) - else: - nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.weight, 1.) nn.init.constant_(m.bias, 0.) + if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1, use_se=False, avg_down=False, down_kernel_size=1, **kwargs): diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 41e19075..7cf4fbe6 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -63,6 +63,9 @@ class SelectiveKernelBasic(nn.Module): self.drop_block = drop_block self.drop_path = drop_path + def zero_init_last_bn(self): + nn.init.zeros_(self.conv2.bn.weight) + def forward(self, x): residual = x x = self.conv1(x) @@ -109,6 +112,9 @@ class SelectiveKernelBottleneck(nn.Module): self.drop_block = drop_block self.drop_path = drop_path + def zero_init_last_bn(self): + nn.init.zeros_(self.conv3.bn.weight) + def forward(self, x): residual = x x = self.conv1(x) From 7d07ebb66075db35560befc921b8f3098e7f79f0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 1 Feb 2020 23:28:48 -0800 Subject: [PATCH 12/13] Adding some configs to sknet, incl ResNet50 variants from 'Compounding ... Assembled Techniques' paper and original SKNet50 --- timm/models/sknet.py | 84 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 70 insertions(+), 14 deletions(-) diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 7cf4fbe6..0c387e39 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -22,7 +22,10 @@ def _cfg(url='', **kwargs): default_cfgs = { 'skresnet18': _cfg(url=''), - 'skresnet26d': _cfg() + 'skresnet26d': _cfg(), + 'skresnet50': _cfg(), + 'skresnet50d': _cfg(), + 'skresnext50_32x4d': _cfg(), } @@ -131,6 +134,41 @@ class SelectiveKernelBottleneck(nn.Module): return x +@register_model +def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['skresnet18'] + sk_kwargs = dict( + min_attn_channels=16, + ) + model = ResNet( + SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['skresnet18'] + sk_kwargs = dict( + min_attn_channels=16, + split_input=True + ) + model = ResNet( + SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-26 model. @@ -150,15 +188,17 @@ def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNet-18 model. +def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Select Kernel ResNet-50 model. 
+ Based on config in "Compounding the Performance Improvements of Assembled Techniques in a + Convolutional Neural Network" """ - default_cfg = default_cfgs['skresnet18'] sk_kwargs = dict( - min_attn_channels=16, + attn_reduction=2, ) + default_cfg = default_cfgs['skresnet50'] model = ResNet( - SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), **kwargs) model.default_cfg = default_cfg if pretrained: @@ -167,18 +207,34 @@ def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNet-18 model. +def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Select Kernel ResNet-50-D model. + Based on config in "Compounding the Performance Improvements of Assembled Techniques in a + Convolutional Neural Network" """ - default_cfg = default_cfgs['skresnet18'] sk_kwargs = dict( - min_attn_channels=16, - split_input=True + attn_reduction=2, ) + default_cfg = default_cfgs['skresnet50d'] model = ResNet( - SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, - block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to + the SKNet50 model in the Select Kernel Paper + """ + default_cfg = default_cfgs['skresnext50_32x4d'] + model = ResNet( + SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, + num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) - return model \ No newline at end of file + return model From 13e8da2b46d8b48fa4bdc76dd89cd7aaf3f7d615 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 7 Feb 2020 22:42:04 -0800 Subject: [PATCH 13/13] SelectKernel split_input works best when input channels split like grouped conv, but output is full width. Disable zero_init for SK nets, seems a bad combo. 
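Before the diff: a minimal standalone sketch of the split_input behaviour described in the commit message above, assuming two kernel paths and a toy 1x1-conv attention head. SplitInputSKSketch and its attribute names are hypothetical, not the patch's SelectiveKernelConv API; the point is only the channel bookkeeping — input channels are divided across the paths the way a grouped conv divides them, while every path still emits the full output width, so the attention-weighted sum leaves out_channels unchanged.

import torch
import torch.nn as nn

class SplitInputSKSketch(nn.Module):
    # Hypothetical toy module: each path sees a slice of the input channels
    # (like a grouped conv), but produces the full out_channels, so the
    # softmax-weighted sum over paths keeps the output width unchanged.
    def __init__(self, in_channels=64, out_channels=64, kernel_sizes=(3, 5)):
        super().__init__()
        self.num_paths = len(kernel_sizes)
        assert in_channels % self.num_paths == 0
        split_chs = in_channels // self.num_paths
        self.paths = nn.ModuleList([
            nn.Conv2d(split_chs, out_channels, k, padding=k // 2, bias=False)
            for k in kernel_sizes])
        # toy attention: per-path logits from globally pooled features
        self.attn = nn.Conv2d(out_channels, out_channels * self.num_paths, kernel_size=1)

    def forward(self, x):
        splits = torch.split(x, x.shape[1] // self.num_paths, dim=1)
        paths = torch.stack([op(s) for op, s in zip(self.paths, splits)], dim=1)  # B, P, C, H, W
        pooled = paths.sum(dim=1).mean(dim=(2, 3), keepdim=True)                  # B, C, 1, 1
        attn = self.attn(pooled).reshape(x.shape[0], self.num_paths, -1, 1, 1)
        attn = torch.softmax(attn, dim=1)
        return (paths * attn).sum(dim=1)  # full out_channels; no N*C reshape as in the removed branch below

print(SplitInputSKSketch()(torch.randn(2, 64, 56, 56)).shape)  # torch.Size([2, 64, 56, 56])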
---
 timm/models/conv2d_layers.py | 22 +++++++---------------
 timm/models/sknet.py         | 12 +++++++-----
 2 files changed, 14 insertions(+), 20 deletions(-)

diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py
index 5b1a44e8..feaf653c 100644
--- a/timm/models/conv2d_layers.py
+++ b/timm/models/conv2d_layers.py
@@ -311,15 +311,13 @@ class SelectiveKernelConv(nn.Module):
             kernel_size = [3] * len(kernel_size)
         else:
             dilation = [dilation] * len(kernel_size)
-        num_paths = len(kernel_size)
-        self.num_paths = num_paths
-        self.split_input = split_input
+        self.num_paths = len(kernel_size)
         self.in_channels = in_channels
         self.out_channels = out_channels
-        if split_input:
-            assert in_channels % num_paths == 0 and out_channels % num_paths == 0
-            in_channels = in_channels // num_paths
-            out_channels = out_channels // num_paths
+        self.split_input = split_input
+        if self.split_input:
+            assert in_channels % self.num_paths == 0
+            in_channels = in_channels // self.num_paths
         groups = min(out_channels, groups)
 
         conv_kwargs = dict(
@@ -329,7 +327,7 @@
             for k, d in zip(kernel_size, dilation)])
 
         attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
-        self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels)
+        self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
         self.drop_block = drop_block
 
     def forward(self, x):
@@ -338,16 +336,10 @@
             x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
         else:
             x_paths = [op(x) for op in self.paths]
-
         x = torch.stack(x_paths, dim=1)
         x_attn = self.attn(x)
         x = x * x_attn
-
-        if self.split_input:
-            B, N, C, H, W = x.shape
-            x = x.reshape(B, N * C, H, W)
-        else:
-            x = torch.sum(x, dim=1)
+        x = torch.sum(x, dim=1)
         return x
 
 
diff --git a/timm/models/sknet.py b/timm/models/sknet.py
index 0c387e39..4b02d501 100644
--- a/timm/models/sknet.py
+++ b/timm/models/sknet.py
@@ -158,11 +158,12 @@ def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['skresnet18']
     sk_kwargs = dict(
         min_attn_channels=16,
+        attn_reduction=8,
         split_input=True
     )
     model = ResNet(
         SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
-        block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
+        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -179,7 +180,7 @@ def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     )
     model = ResNet(
         SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
-        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
+        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False,
         **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
@@ -199,7 +200,7 @@ def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['skresnet50']
     model = ResNet(
         SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
-        block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
+        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -218,7 +219,8 @@ def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['skresnet50d']
     model = ResNet(
         SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
-        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
+        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
+        zero_init_last_bn=False, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -233,7 +235,7 @@ def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['skresnext50_32x4d']
     model = ResNet(
         SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
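A usage sketch, assuming the whole series is applied to a timm checkout of this vintage: the new variants are reachable through the model registry via timm.create_model. None of the _cfg entries above carry weight URLs, so pretrained has to stay False; no accuracy numbers are claimed for these configs anywhere in the series.

import torch
import timm

# 'skresnext50_32x4d' is registered by the patches above; weights are not
# published with this series, so pretrained stays False.
model = timm.create_model('skresnext50_32x4d', pretrained=False, num_classes=1000)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])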