From c8b3d6b81a478ec72b8d5f75015b3859af926df1 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jan 2020 19:45:05 -0800 Subject: [PATCH 01/23] Initial impl of Selective Kernel Networks. Very much a WIP. --- timm/models/resnet.py | 146 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 422eb0cb..b3adf6dc 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -6,6 +6,7 @@ additional dropout and dynamic global avg/max pool. ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman """ import math +from collections import OrderedDict import torch import torch.nn as nn @@ -100,6 +101,7 @@ default_cfgs = { 'seresnext26tn_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth', interpolation='bicubic'), + 'skresnet26d': _cfg() } @@ -232,6 +234,137 @@ class Bottleneck(nn.Module): return out +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, num_attn_feat=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc_reduce = nn.Conv2d(channels, num_attn_feat, kernel_size=1, bias=False) + self.bn = norm_layer(num_attn_feat) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(num_attn_feat, channels * num_paths, kernel_size=1) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = torch.sum(x, dim=1) + #print('attn sum', x.shape) + x = self.pool(x) + #print('attn pool', x.shape) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + #print('attn sel', x.shape) + x = x.view((x.shape[0], self.num_paths, x.shape[1]//self.num_paths) + x.shape[-2:]) + #print('attn spl', x.shape) + x = torch.softmax(x, dim=1) + return x + + +class SelectiveKernelConv(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=[3, 5], attn_reduction=16, + min_attn_feat=32, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelConv, self).__init__() + if not isinstance(kernel_size, list): + assert kernel_size >= 3 and kernel_size % 2 + kernel_size = [kernel_size] * 2 + else: + # FIXME assert kernel sizes >=3 and odd + pass + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + groups = min(out_channels // len(kernel_size), groups) + + self.conv_paths = nn.ModuleList() + for k, d in zip(kernel_size, dilation): + p = _get_padding(k, stride, d) + self.conv_paths.append(nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d( + in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)), + ('bn', norm_layer(out_channels)), + ('act', act_layer(inplace=True)) + ]))) + + if use_attn: + num_attn_feat = max(int(out_channels / attn_reduction), min_attn_feat) + self.attn = SelectiveKernelAttn(out_channels, len(kernel_size), num_attn_feat) + else: + self.attn = None + + def forward(self, x): + x_paths = [] + for conv in self.conv_paths: + xk = conv(x) + x_paths.append(xk) + if self.attn is not None: + x_paths = torch.stack(x_paths, dim=1) + # print('paths', x_paths.shape) + x_attn = self.attn(x_paths) + #print('attn', x_attn.shape) + x = x_paths * x_attn + #print('amul', x.shape) + x = torch.sum(x, dim=1) + #print('asum', x.shape) + else: + x = torch.cat(x_paths, dim=1) + return x + + +class SelectiveKernelBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelBottleneck, self).__init__() + + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + + self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + self.conv2 = SelectiveKernelConv( + first_planes, width, stride=stride, dilation=dilation, groups=cardinality) + self.bn2 = norm_layer(width) + self.act2 = act_layer(inplace=True) + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + self.act3 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.act1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act2(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act3(out) + + return out + + class ResNet(nn.Module): """ResNet / ResNeXt / SE-ResNeXt / SE-Net @@ -472,6 +605,19 @@ def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-26 model. + """ + default_cfg = default_cfgs['skresnet26d'] + model = ResNet( + SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + @register_model def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-26 v1d model. From ad087b4b1785d5a6f0d2404b8a30ee9ee7add8f0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jan 2020 19:54:37 -0800 Subject: [PATCH 02/23] Missed bias=False in selection conv --- timm/models/resnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index b3adf6dc..57a20894 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -243,7 +243,7 @@ class SelectiveKernelAttn(nn.Module): self.fc_reduce = nn.Conv2d(channels, num_attn_feat, kernel_size=1, bias=False) self.bn = norm_layer(num_attn_feat) self.act = act_layer(inplace=True) - self.fc_select = nn.Conv2d(num_attn_feat, channels * num_paths, kernel_size=1) + self.fc_select = nn.Conv2d(num_attn_feat, channels * num_paths, kernel_size=1, bias=False) def forward(self, x): assert x.shape[1] == self.num_paths From a93bae6dc5a8831f1633208ac45d598a225810c5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 25 Jan 2020 18:31:08 -0800 Subject: [PATCH 03/23] A SelectiveKernelBasicBlock for more experiments --- timm/models/resnet.py | 61 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 57a20894..1d64dcd9 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -265,7 +265,7 @@ class SelectiveKernelAttn(nn.Module): class SelectiveKernelConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=[3, 5], attn_reduction=16, - min_attn_feat=32, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True, + min_attn_feat=16, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelConv, self).__init__() if not isinstance(kernel_size, list): @@ -316,6 +316,53 @@ class SelectiveKernelConv(nn.Module): return x +class SelectiveKernelBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelBasicBlock, self).__init__() + + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock doest not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, + dilation=dilation, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + self.conv2 = SelectiveKernelConv(first_planes, outplanes, dilation=previous_dilation) + self.bn2 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None + self.act2 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.act1(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.se is not None: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act2(out) + + return out + + class SelectiveKernelBottleneck(nn.Module): expansion = 4 @@ -581,6 +628,18 @@ def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['resnet18'] + model = ResNet(SelectiveKernelBasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-34 model. From 58e28dc7e7a988fe3566aa15150cfc76261eccc7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 26 Jan 2020 11:23:39 -0800 Subject: [PATCH 04/23] Move Selective Kernel blocks/convs to their own sknet.py file --- timm/models/__init__.py | 1 + timm/models/resnet.py | 209 +--------------------------- timm/models/sknet.py | 294 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 297 insertions(+), 207 deletions(-) create mode 100644 timm/models/sknet.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 0fa4d210..69e09085 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -16,6 +16,7 @@ from .gluon_xception import * from .res2net import * from .dla import * from .hrnet import * +from .sknet import * from .registry import * from .factory import create_model diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 1d64dcd9..c3104530 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -6,7 +6,6 @@ additional dropout and dynamic global avg/max pool. ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman """ import math -from collections import OrderedDict import torch import torch.nn as nn @@ -101,11 +100,10 @@ default_cfgs = { 'seresnext26tn_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth', interpolation='bicubic'), - 'skresnet26d': _cfg() } -def _get_padding(kernel_size, stride, dilation=1): +def get_padding(kernel_size, stride, dilation=1): padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 return padding @@ -234,184 +232,6 @@ class Bottleneck(nn.Module): return out -class SelectiveKernelAttn(nn.Module): - def __init__(self, channels, num_paths=2, num_attn_feat=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelAttn, self).__init__() - self.num_paths = num_paths - self.pool = nn.AdaptiveAvgPool2d(1) - self.fc_reduce = nn.Conv2d(channels, num_attn_feat, kernel_size=1, bias=False) - self.bn = norm_layer(num_attn_feat) - self.act = act_layer(inplace=True) - self.fc_select = nn.Conv2d(num_attn_feat, channels * num_paths, kernel_size=1, bias=False) - - def forward(self, x): - assert x.shape[1] == self.num_paths - x = torch.sum(x, dim=1) - #print('attn sum', x.shape) - x = self.pool(x) - #print('attn pool', x.shape) - x = self.fc_reduce(x) - x = self.bn(x) - x = self.act(x) - x = self.fc_select(x) - #print('attn sel', x.shape) - x = x.view((x.shape[0], self.num_paths, x.shape[1]//self.num_paths) + x.shape[-2:]) - #print('attn spl', x.shape) - x = torch.softmax(x, dim=1) - return x - - -class SelectiveKernelConv(nn.Module): - - def __init__(self, in_channels, out_channels, kernel_size=[3, 5], attn_reduction=16, - min_attn_feat=16, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelConv, self).__init__() - if not isinstance(kernel_size, list): - assert kernel_size >= 3 and kernel_size % 2 - kernel_size = [kernel_size] * 2 - else: - # FIXME assert kernel sizes >=3 and odd - pass - if keep_3x3: - dilation = [dilation * (k - 1) // 2 for k in kernel_size] - kernel_size = [3] * len(kernel_size) - else: - dilation = [dilation] * len(kernel_size) - groups = min(out_channels // len(kernel_size), groups) - - self.conv_paths = nn.ModuleList() - for k, d in zip(kernel_size, dilation): - p = _get_padding(k, stride, d) - self.conv_paths.append(nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d( - in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)), - ('bn', norm_layer(out_channels)), - ('act', act_layer(inplace=True)) - ]))) - - if use_attn: - num_attn_feat = max(int(out_channels / attn_reduction), min_attn_feat) - self.attn = SelectiveKernelAttn(out_channels, len(kernel_size), num_attn_feat) - else: - self.attn = None - - def forward(self, x): - x_paths = [] - for conv in self.conv_paths: - xk = conv(x) - x_paths.append(xk) - if self.attn is not None: - x_paths = torch.stack(x_paths, dim=1) - # print('paths', x_paths.shape) - x_attn = self.attn(x_paths) - #print('attn', x_attn.shape) - x = x_paths * x_attn - #print('amul', x.shape) - x = torch.sum(x, dim=1) - #print('asum', x.shape) - else: - x = torch.cat(x_paths, dim=1) - return x - - -class SelectiveKernelBasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelBasicBlock, self).__init__() - - assert cardinality == 1, 'BasicBlock only supports cardinality of 1' - assert base_width == 64, 'BasicBlock doest not support changing base width' - first_planes = planes // reduce_first - outplanes = planes * self.expansion - - self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, bias=False) - self.bn1 = norm_layer(first_planes) - self.act1 = act_layer(inplace=True) - self.conv2 = SelectiveKernelConv(first_planes, outplanes, dilation=previous_dilation) - self.bn2 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None - self.act2 = act_layer(inplace=True) - self.downsample = downsample - self.stride = stride - self.dilation = dilation - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - out = self.conv2(out) - out = self.bn2(out) - - if self.se is not None: - out = self.se(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.act2(out) - - return out - - -class SelectiveKernelBottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelBottleneck, self).__init__() - - width = int(math.floor(planes * (base_width / 64)) * cardinality) - first_planes = width // reduce_first - outplanes = planes * self.expansion - - self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) - self.bn1 = norm_layer(first_planes) - self.act1 = act_layer(inplace=True) - self.conv2 = SelectiveKernelConv( - first_planes, width, stride=stride, dilation=dilation, groups=cardinality) - self.bn2 = norm_layer(width) - self.act2 = act_layer(inplace=True) - self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) - self.bn3 = norm_layer(outplanes) - self.act3 = act_layer(inplace=True) - self.downsample = downsample - self.stride = stride - self.dilation = dilation - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.act2(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.act3(out) - - return out - - class ResNet(nn.Module): """ResNet / ResNeXt / SE-ResNeXt / SE-Net @@ -560,7 +380,7 @@ class ResNet(nn.Module): downsample = None down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size if stride != 1 or self.inplanes != planes * block.expansion: - downsample_padding = _get_padding(down_kernel_size, stride) + downsample_padding = get_padding(down_kernel_size, stride) downsample_layers = [] conv_stride = stride if avg_down: @@ -628,18 +448,6 @@ def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model -@register_model -def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNet-18 model. - """ - default_cfg = default_cfgs['resnet18'] - model = ResNet(SelectiveKernelBasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - @register_model def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-34 model. @@ -664,19 +472,6 @@ def resnet26(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model -@register_model -def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNet-26 model. - """ - default_cfg = default_cfgs['skresnet26d'] - model = ResNet( - SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - @register_model def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-26 v1d model. diff --git a/timm/models/sknet.py b/timm/models/sknet.py new file mode 100644 index 00000000..4bc2061d --- /dev/null +++ b/timm/models/sknet.py @@ -0,0 +1,294 @@ +import math +from collections import OrderedDict + +import torch +from torch import nn as nn + +from timm.models.registry import register_model +from timm.models.helpers import load_pretrained +from timm.models.resnet import ResNet, get_padding, SEModule +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'conv1', 'classifier': 'fc', + **kwargs + } + + +default_cfgs = { + 'skresnet18': _cfg(url=''), + 'skresnet26d': _cfg() +} + + +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = torch.sum(x, dim=1) + #print('attn sum', x.shape) + x = self.pool(x) + #print('attn pool', x.shape) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + #print('attn sel', x.shape) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + #print('attn spl', x.shape) + x = torch.softmax(x, dim=1) + return x + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelConv(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=[3, 5], stride=1, dilation=1, groups=1, + attn_reduction=16, min_attn_channels=32, keep_3x3=True, use_attn=True, + split_input=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelConv, self).__init__() + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + num_paths = len(kernel_size) + self.num_paths = num_paths + self.split_input = split_input + self.in_channels = in_channels + self.out_channels = out_channels + if split_input: + assert in_channels % num_paths == 0 and out_channels % num_paths == 0 + in_channels = in_channels // num_paths + out_channels = out_channels // num_paths + groups = min(out_channels, groups) + + self.paths = nn.ModuleList() + for k, d in zip(kernel_size, dilation): + p = get_padding(k, stride, d) + self.paths.append(nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d( + in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)), + ('bn', norm_layer(out_channels)), + ('act', act_layer(inplace=True)) + ]))) + + if use_attn: + attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) + self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels) + else: + self.attn = None + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.out_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + + if self.attn is not None: + x = torch.stack(x_paths, dim=1) + # print('paths', x_paths.shape) + x_attn = self.attn(x) + #print('attn', x_attn.shape) + x = x * x_attn + #print('amul', x.shape) + + if self.split_input: + B, N, C, H, W = x.shape + x = x.reshape(B, N * C, H, W) + else: + x = torch.sum(x, dim=1) + #print('aout', x.shape) + return x + + +class SelectiveKernelBasic(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, use_se=False, sk_kwargs=None, + reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelBasic, self).__init__() + + sk_kwargs = sk_kwargs or {} + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock doest not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + + _selective_first = True # FIXME temporary, for experiments + if _selective_first: + self.conv1 = SelectiveKernelConv( + inplanes, first_planes, stride=stride, dilation=dilation, **sk_kwargs) + else: + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, + dilation=dilation, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + if _selective_first: + self.conv2 = nn.Conv2d( + first_planes, outplanes, kernel_size=3, padding=previous_dilation, + dilation=previous_dilation, bias=False) + else: + self.conv2 = SelectiveKernelConv( + first_planes, outplanes, dilation=previous_dilation, **sk_kwargs) + self.bn2 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None + self.act2 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.act1(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.se is not None: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act2(out) + + return out + + +class SelectiveKernelBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, use_se=False, sk_kwargs=None, + reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelBottleneck, self).__init__() + + sk_kwargs = sk_kwargs or {} + width = int(math.floor(planes * (base_width / 64)) * cardinality) + first_planes = width // reduce_first + outplanes = planes * self.expansion + + self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + self.conv2 = SelectiveKernelConv( + first_planes, width, stride=stride, dilation=dilation, groups=cardinality, **sk_kwargs) + self.bn2 = norm_layer(width) + self.act2 = act_layer(inplace=True) + self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) + self.bn3 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None + self.act3 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.act1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.act2(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.se is not None: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act3(out) + + return out + + +@register_model +def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-26 model. + """ + default_cfg = default_cfgs['skresnet26d'] + sk_kwargs = dict( + keep_3x3=False, + ) + model = ResNet( + SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), + **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['skresnet18'] + sk_kwargs = dict( + min_attn_channels=16, + ) + model = ResNet( + SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['skresnet18'] + sk_kwargs = dict( + min_attn_channels=16, + split_input=True + ) + model = ResNet( + SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model \ No newline at end of file From 9abe6109316bca240ca71a32806b1e6e39979ce0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 26 Jan 2020 11:33:31 -0800 Subject: [PATCH 05/23] Used wrong channel var for split --- timm/models/sknet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 4bc2061d..d3a4fb5d 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -106,7 +106,7 @@ class SelectiveKernelConv(nn.Module): def forward(self, x): if self.split_input: - x_split = torch.split(x, self.out_channels // self.num_paths, 1) + x_split = torch.split(x, self.in_channels // self.num_paths, 1) x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] else: x_paths = [op(x) for op in self.paths] From cefc9b7761584f7118e65a6bbdcc3887f3b0de62 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 27 Jan 2020 21:48:28 -0800 Subject: [PATCH 06/23] Move SelectKernelConv to conv2d_layers and more * always apply attention in SelectKernelConv, leave MixedConv for no attention alternative * make MixedConv torchscript compatible * refactor first/previous dilation name to make more sense in ResNet* networks --- timm/models/conv2d_layers.py | 102 ++++++++++++++++++++++++++-- timm/models/res2net.py | 9 +-- timm/models/resnet.py | 33 +++++---- timm/models/sknet.py | 126 ++++------------------------------- 4 files changed, 128 insertions(+), 142 deletions(-) diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py index acd14fde..7583263a 100644 --- a/timm/models/conv2d_layers.py +++ b/timm/models/conv2d_layers.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import torch import torch.nn as nn import torch.nn.functional as F @@ -100,14 +102,11 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) -class MixedConv2d(nn.Module): +class MixedConv2d(nn.ModuleDict): """ Mixed Grouped Convolution Based on MDConv and GroupedConv in MixNet impl: https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py - - NOTE: This does not currently work with torch.jit.script """ - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='', dilation=1, depthwise=False, **kwargs): super(MixedConv2d, self).__init__() @@ -131,7 +130,7 @@ class MixedConv2d(nn.Module): def forward(self, x): x_split = torch.split(x, self.splits, 1) - x_out = [c(x) for x, c in zip(x_split, self._modules.values())] + x_out = [c(x_split[i]) for i, c in enumerate(self.values())] x = torch.cat(x_out, 1) return x @@ -240,6 +239,97 @@ class CondConv2d(nn.Module): return out +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = torch.sum(x, dim=1) + x = self.pool(x) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + x = torch.softmax(x, dim=1) + return x + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelConv(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, + attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelConv, self).__init__() + kernel_size = kernel_size or [3, 5] + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + num_paths = len(kernel_size) + self.num_paths = num_paths + self.split_input = split_input + self.in_channels = in_channels + self.out_channels = out_channels + if split_input: + assert in_channels % num_paths == 0 and out_channels % num_paths == 0 + in_channels = in_channels // num_paths + out_channels = out_channels // num_paths + groups = min(out_channels, groups) + + self.paths = nn.ModuleList() + for k, d in zip(kernel_size, dilation): + p = _get_padding(k, stride, d) + self.paths.append(nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d( + in_channels, out_channels, kernel_size=k, stride=stride, padding=p, + dilation=d, groups=groups, bias=False)), + ('bn', norm_layer(out_channels)), + ('act', act_layer(inplace=True)) + ]))) + + attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) + self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels) + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.in_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + + x = torch.stack(x_paths, dim=1) + x_attn = self.attn(x) + x = x * x_attn + + if self.split_input: + B, N, C, H, W = x.shape + x = x.reshape(B, N * C, H, W) + else: + x = torch.sum(x, dim=1) + return x + + # helper method def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): assert 'groups' not in kwargs # only use 'depthwise' bool arg @@ -256,5 +346,3 @@ def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): else: m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) return m - - diff --git a/timm/models/res2net.py b/timm/models/res2net.py index da20e7a0..c83aba62 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -54,14 +54,15 @@ class Bottle2neck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=26, scale=4, use_se=False, - act_layer=nn.ReLU, norm_layer=None, dilation=1, previous_dilation=1, **_): + act_layer=nn.ReLU, norm_layer=None, dilation=1, first_dilation=None, **_): super(Bottle2neck, self).__init__() self.scale = scale self.is_first = stride > 1 or downsample is not None self.num_scales = max(1, scale - 1) width = int(math.floor(planes * (base_width / 64.0))) * cardinality - outplanes = planes * self.expansion self.width = width + outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) self.bn1 = norm_layer(width * scale) @@ -70,8 +71,8 @@ class Bottle2neck(nn.Module): bns = [] for i in range(self.num_scales): convs.append(nn.Conv2d( - width, width, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, groups=cardinality, bias=False)) + width, width, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, groups=cardinality, bias=False)) bns.append(norm_layer(width)) self.convs = nn.ModuleList(convs) self.bns = nn.ModuleList(bns) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index c3104530..976d4234 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -131,24 +131,23 @@ class BasicBlock(nn.Module): __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, bias=False) + inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( - first_planes, outplanes, kernel_size=3, padding=previous_dilation, - dilation=previous_dilation, bias=False) + first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) self.bn2 = norm_layer(outplanes) self.se = SEModule(outplanes, planes // 4) if use_se else None self.act2 = act_layer(inplace=True) @@ -181,21 +180,21 @@ class Bottleneck(nn.Module): __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = nn.Conv2d( first_planes, width, kernel_size=3, stride=stride, - padding=dilation, dilation=dilation, groups=cardinality, bias=False) + padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) self.bn2 = norm_layer(width) self.act2 = act_layer(inplace=True) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) @@ -396,13 +395,11 @@ class ResNet(nn.Module): first_dilation = 1 if dilation in (1, 2) else 2 bkwargs = dict( cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, - use_se=use_se, **kwargs) - layers = [block( - self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)] + dilation=dilation, use_se=use_se, **kwargs) + layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **bkwargs)] self.inplanes = planes * block.expansion for i in range(1, blocks): - layers.append(block( - self.inplanes, planes, dilation=dilation, previous_dilation=dilation, **bkwargs)) + layers.append(block(self.inplanes, planes, **bkwargs)) return nn.Sequential(*layers) @@ -430,8 +427,8 @@ class ResNet(nn.Module): def forward(self, x): x = self.forward_features(x) x = self.global_pool(x).flatten(1) - if self.drop_rate > 0.: - x = F.dropout(x, p=self.drop_rate, training=self.training) + if self.drop_rate: + x = F.dropout(x, p=float(self.drop_rate), training=self.training) x = self.fc(x) return x diff --git a/timm/models/sknet.py b/timm/models/sknet.py index d3a4fb5d..e9dbf68d 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -1,12 +1,11 @@ import math -from collections import OrderedDict -import torch from torch import nn as nn from timm.models.registry import register_model from timm.models.helpers import load_pretrained -from timm.models.resnet import ResNet, get_padding, SEModule +from timm.models.conv2d_layers import SelectiveKernelConv +from timm.models.resnet import ResNet, SEModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -27,113 +26,12 @@ default_cfgs = { } -class SelectiveKernelAttn(nn.Module): - def __init__(self, channels, num_paths=2, attn_channels=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelAttn, self).__init__() - self.num_paths = num_paths - self.pool = nn.AdaptiveAvgPool2d(1) - self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) - self.bn = norm_layer(attn_channels) - self.act = act_layer(inplace=True) - self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) - - def forward(self, x): - assert x.shape[1] == self.num_paths - x = torch.sum(x, dim=1) - #print('attn sum', x.shape) - x = self.pool(x) - #print('attn pool', x.shape) - x = self.fc_reduce(x) - x = self.bn(x) - x = self.act(x) - x = self.fc_select(x) - #print('attn sel', x.shape) - B, C, H, W = x.shape - x = x.view(B, self.num_paths, C // self.num_paths, H, W) - #print('attn spl', x.shape) - x = torch.softmax(x, dim=1) - return x - - -def _kernel_valid(k): - if isinstance(k, (list, tuple)): - for ki in k: - return _kernel_valid(ki) - assert k >= 3 and k % 2 - - -class SelectiveKernelConv(nn.Module): - - def __init__(self, in_channels, out_channels, kernel_size=[3, 5], stride=1, dilation=1, groups=1, - attn_reduction=16, min_attn_channels=32, keep_3x3=True, use_attn=True, - split_input=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelConv, self).__init__() - _kernel_valid(kernel_size) - if not isinstance(kernel_size, list): - kernel_size = [kernel_size] * 2 - if keep_3x3: - dilation = [dilation * (k - 1) // 2 for k in kernel_size] - kernel_size = [3] * len(kernel_size) - else: - dilation = [dilation] * len(kernel_size) - num_paths = len(kernel_size) - self.num_paths = num_paths - self.split_input = split_input - self.in_channels = in_channels - self.out_channels = out_channels - if split_input: - assert in_channels % num_paths == 0 and out_channels % num_paths == 0 - in_channels = in_channels // num_paths - out_channels = out_channels // num_paths - groups = min(out_channels, groups) - - self.paths = nn.ModuleList() - for k, d in zip(kernel_size, dilation): - p = get_padding(k, stride, d) - self.paths.append(nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d( - in_channels, out_channels, kernel_size=k, stride=stride, padding=p, dilation=d, groups=groups)), - ('bn', norm_layer(out_channels)), - ('act', act_layer(inplace=True)) - ]))) - - if use_attn: - attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) - self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels) - else: - self.attn = None - - def forward(self, x): - if self.split_input: - x_split = torch.split(x, self.in_channels // self.num_paths, 1) - x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] - else: - x_paths = [op(x) for op in self.paths] - - if self.attn is not None: - x = torch.stack(x_paths, dim=1) - # print('paths', x_paths.shape) - x_attn = self.attn(x) - #print('attn', x_attn.shape) - x = x * x_attn - #print('amul', x.shape) - - if self.split_input: - B, N, C, H, W = x.shape - x = x.reshape(B, N * C, H, W) - else: - x = torch.sum(x, dim=1) - #print('aout', x.shape) - return x - - class SelectiveKernelBasic(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, sk_kwargs=None, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelBasic, self).__init__() sk_kwargs = sk_kwargs or {} @@ -141,24 +39,25 @@ class SelectiveKernelBasic(nn.Module): assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation _selective_first = True # FIXME temporary, for experiments if _selective_first: self.conv1 = SelectiveKernelConv( - inplanes, first_planes, stride=stride, dilation=dilation, **sk_kwargs) + inplanes, first_planes, stride=stride, dilation=first_dilation, **sk_kwargs) else: self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, - dilation=dilation, bias=False) + inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, + dilation=first_dilation, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) if _selective_first: self.conv2 = nn.Conv2d( - first_planes, outplanes, kernel_size=3, padding=previous_dilation, - dilation=previous_dilation, bias=False) + first_planes, outplanes, kernel_size=3, padding=dilation, + dilation=dilation, bias=False) else: self.conv2 = SelectiveKernelConv( - first_planes, outplanes, dilation=previous_dilation, **sk_kwargs) + first_planes, outplanes, dilation=dilation, **sk_kwargs) self.bn2 = norm_layer(outplanes) self.se = SEModule(outplanes, planes // 4) if use_se else None self.act2 = act_layer(inplace=True) @@ -192,19 +91,20 @@ class SelectiveKernelBottleneck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, sk_kwargs=None, - reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelBottleneck, self).__init__() sk_kwargs = sk_kwargs or {} width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion + first_dilation = first_dilation or dilation self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) self.bn1 = norm_layer(first_planes) self.act1 = act_layer(inplace=True) self.conv2 = SelectiveKernelConv( - first_planes, width, stride=stride, dilation=dilation, groups=cardinality, **sk_kwargs) + first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, **sk_kwargs) self.bn2 = norm_layer(width) self.act2 = act_layer(inplace=True) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) From 9f11b4e8a25495874d84a56d4ca11af191a01324 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 29 Jan 2020 13:01:35 -0800 Subject: [PATCH 07/23] Add ConvBnAct layer to parallel integrated SelectKernelConv, add support for DropPath and DropBlock to ResNet base and SK blocks --- timm/models/conv2d_layers.py | 43 +++++++++---- timm/models/resnet.py | 20 +++--- timm/models/sknet.py | 118 +++++++++++++++-------------------- 3 files changed, 96 insertions(+), 85 deletions(-) diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py index 7583263a..5b1a44e8 100644 --- a/timm/models/conv2d_layers.py +++ b/timm/models/conv2d_layers.py @@ -271,11 +271,36 @@ def _kernel_valid(k): assert k >= 3 and k % 2 +class ConvBnAct(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(ConvBnAct, self).__init__() + padding = _get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block + self.conv = nn.Conv2d( + in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=False) + self.bn = norm_layer(out_channels) + self.drop_block = drop_block + if act_layer is not None: + self.act = act_layer(inplace=True) + else: + self.act = None + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.drop_block is not None: + x = self.drop_block(x) + if self.act is not None: + x = self.act(x) + return x + + class SelectiveKernelConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelConv, self).__init__() kernel_size = kernel_size or [3, 5] _kernel_valid(kernel_size) @@ -297,19 +322,15 @@ class SelectiveKernelConv(nn.Module): out_channels = out_channels // num_paths groups = min(out_channels, groups) - self.paths = nn.ModuleList() - for k, d in zip(kernel_size, dilation): - p = _get_padding(k, stride, d) - self.paths.append(nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d( - in_channels, out_channels, kernel_size=k, stride=stride, padding=p, - dilation=d, groups=groups, bias=False)), - ('bn', norm_layer(out_channels)), - ('act', act_layer(inplace=True)) - ]))) + conv_kwargs = dict( + stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) + self.paths = nn.ModuleList([ + ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + for k, d in zip(kernel_size, dilation)]) attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels) + self.drop_block = drop_block def forward(self, x): if self.split_input: diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 976d4234..3e0ce23e 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -14,6 +14,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .nn_ops import DropBlock2d, DropPath from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -132,7 +133,8 @@ class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + drop_block=None, drop_path=None): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' @@ -181,7 +183,8 @@ class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, - reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + drop_block=None, drop_path=None): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) @@ -305,8 +308,8 @@ class ResNet(nn.Module): def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, cardinality=1, base_width=64, stem_width=64, stem_type='', block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg', - zero_init_last_bn=True, block_args=None): + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0., + drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): block_args = block_args or dict() self.num_classes = num_classes deep_stem = 'deep' in stem_type @@ -338,6 +341,9 @@ class ResNet(nn.Module): self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) # Feature Blocks + dp = DropPath(drop_path_rate) if drop_block_rate else None + db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None + db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4 if output_stride == 16: strides[3] = 1 @@ -350,11 +356,11 @@ class ResNet(nn.Module): llargs = list(zip(channels, layers, strides, dilations)) lkwargs = dict( use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, - avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args) + avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args) self.layer1 = self._make_layer(block, *llargs[0], **lkwargs) self.layer2 = self._make_layer(block, *llargs[1], **lkwargs) - self.layer3 = self._make_layer(block, *llargs[2], **lkwargs) - self.layer4 = self._make_layer(block, *llargs[3], **lkwargs) + self.layer3 = self._make_layer(block, drop_block=db_3, *llargs[2], **lkwargs) + self.layer4 = self._make_layer(block, drop_block=db_4, *llargs[3], **lkwargs) # Head (Pooling and Classifier) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) diff --git a/timm/models/sknet.py b/timm/models/sknet.py index e9dbf68d..41e19075 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -4,7 +4,7 @@ from torch import nn as nn from timm.models.registry import register_model from timm.models.helpers import load_pretrained -from timm.models.conv2d_layers import SelectiveKernelConv +from timm.models.conv2d_layers import SelectiveKernelConv, ConvBnAct from timm.models.resnet import ResNet, SEModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -29,61 +29,53 @@ default_cfgs = { class SelectiveKernelBasic(nn.Module): expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, sk_kwargs=None, - reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, + use_se=False, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, + drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelBasic, self).__init__() sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first - outplanes = planes * self.expansion + out_planes = planes * self.expansion first_dilation = first_dilation or dilation _selective_first = True # FIXME temporary, for experiments if _selective_first: self.conv1 = SelectiveKernelConv( - inplanes, first_planes, stride=stride, dilation=first_dilation, **sk_kwargs) - else: - self.conv1 = nn.Conv2d( - inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation, - dilation=first_dilation, bias=False) - self.bn1 = norm_layer(first_planes) - self.act1 = act_layer(inplace=True) - if _selective_first: - self.conv2 = nn.Conv2d( - first_planes, outplanes, kernel_size=3, padding=dilation, - dilation=dilation, bias=False) + inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv2 = ConvBnAct( + first_planes, out_planes, kernel_size=3, dilation=dilation, **conv_kwargs) else: + self.conv1 = ConvBnAct( + inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs) + conv_kwargs['act_layer'] = None self.conv2 = SelectiveKernelConv( - first_planes, outplanes, dilation=dilation, **sk_kwargs) - self.bn2 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None - self.act2 = act_layer(inplace=True) + first_planes, out_planes, dilation=dilation, **conv_kwargs, **sk_kwargs) + self.se = SEModule(out_planes, planes // 4) if use_se else None + self.act = act_layer(inplace=True) self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path def forward(self, x): residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - out = self.conv2(out) - out = self.bn2(out) - + x = self.conv1(x) + x = self.conv2(x) if self.se is not None: - out = self.se(out) - + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.act2(out) - - return out + residual = self.downsample(residual) + x += residual + x = self.act(x) + return x class SelectiveKernelBottleneck(nn.Module): @@ -91,54 +83,46 @@ class SelectiveKernelBottleneck(nn.Module): def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, sk_kwargs=None, - reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + reduce_first=1, dilation=1, first_dilation=None, + drop_block=None, drop_path=None, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelBottleneck, self).__init__() sk_kwargs = sk_kwargs or {} + conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first - outplanes = planes * self.expansion + out_planes = planes * self.expansion first_dilation = first_dilation or dilation - self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) - self.bn1 = norm_layer(first_planes) - self.act1 = act_layer(inplace=True) + self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) self.conv2 = SelectiveKernelConv( - first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, **sk_kwargs) - self.bn2 = norm_layer(width) - self.act2 = act_layer(inplace=True) - self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) - self.bn3 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None - self.act3 = act_layer(inplace=True) + first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, + **conv_kwargs, **sk_kwargs) + conv_kwargs['act_layer'] = None + self.conv3 = ConvBnAct(width, out_planes, kernel_size=1, **conv_kwargs) + self.se = SEModule(out_planes, planes // 4) if use_se else None + self.act = act_layer(inplace=True) self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path def forward(self, x): residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.act2(out) - - out = self.conv3(out) - out = self.bn3(out) - + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) if self.se is not None: - out = self.se(out) - + x = self.se(x) + if self.drop_path is not None: + x = self.drop_path(x) if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.act3(out) - - return out + residual = self.downsample(residual) + x += residual + x = self.act(x) + return x @register_model From 3ff19079f993e9a91206a04c2b61896883064eeb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 29 Jan 2020 13:11:38 -0800 Subject: [PATCH 08/23] Missed nn_ops.py from last commit --- timm/models/nn_ops.py | 146 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 timm/models/nn_ops.py diff --git a/timm/models/nn_ops.py b/timm/models/nn_ops.py new file mode 100644 index 00000000..37286611 --- /dev/null +++ b/timm/models/nn_ops.py @@ -0,0 +1,146 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + +## Assembled CNN Tensorflow Impl +# +# def _bernoulli(shape, mean, seed=None, dtype=tf.float32): +# return tf.nn.relu(tf.sign(mean - tf.random_uniform(shape, minval=0, maxval=1, dtype=dtype, seed=seed))) +# +# +# def dropblock(x, keep_prob, block_size, gamma_scale=1.0, seed=None, name=None, +# data_format='channels_last', is_training=True): # pylint: disable=invalid-name +# """ +# Dropblock layer. For more details, refer to https://arxiv.org/abs/1810.12890 +# :param x: A floating point tensor. +# :param keep_prob: A scalar Tensor with the same type as x. The probability that each element is kept. +# :param block_size: The block size to drop +# :param gamma_scale: The multiplier to gamma. +# :param seed: Python integer. Used to create random seeds. +# :param name: A name for this operation (optional) +# :param data_format: 'channels_last' or 'channels_first' +# :param is_training: If False, do nothing. +# :return: A Tensor of the same shape of x. +# """ +# if not is_training: +# return x +# +# # Early return if nothing needs to be dropped. +# if (isinstance(keep_prob, float) and keep_prob == 1) or gamma_scale == 0: +# return x +# +# with tf.name_scope(name, "dropblock", [x]) as name: +# if not x.dtype.is_floating: +# raise ValueError("x has to be a floating point tensor since it's going to" +# " be scaled. Got a %s tensor instead." % x.dtype) +# if isinstance(keep_prob, float) and not 0 < keep_prob <= 1: +# raise ValueError("keep_prob must be a scalar tensor or a float in the " +# "range (0, 1], got %g" % keep_prob) +# +# br = (block_size - 1) // 2 +# tl = (block_size - 1) - br +# if data_format == 'channels_last': +# _, h, w, c = x.shape.as_list() +# sampling_mask_shape = tf.stack([1, h - block_size + 1, w - block_size + 1, c]) +# pad_shape = [[0, 0], [tl, br], [tl, br], [0, 0]] +# elif data_format == 'channels_first': +# _, c, h, w = x.shape.as_list() +# sampling_mask_shape = tf.stack([1, c, h - block_size + 1, w - block_size + 1]) +# pad_shape = [[0, 0], [0, 0], [tl, br], [tl, br]] +# else: +# raise NotImplementedError +# +# gamma = (1. - keep_prob) * (w * h) / (block_size ** 2) / ((w - block_size + 1) * (h - block_size + 1)) +# gamma = gamma_scale * gamma +# mask = _bernoulli(sampling_mask_shape, gamma, seed, tf.float32) +# mask = tf.pad(mask, pad_shape) +# +# xdtype_mask = tf.cast(mask, x.dtype) +# xdtype_mask = tf.layers.max_pooling2d( +# inputs=xdtype_mask, pool_size=block_size, +# strides=1, padding='SAME', +# data_format=data_format) +# +# xdtype_mask = 1 - xdtype_mask +# fp32_mask = tf.cast(xdtype_mask, tf.float32) +# ret = tf.multiply(x, xdtype_mask) +# float32_mask_size = tf.cast(tf.size(fp32_mask), tf.float32) +# float32_mask_reduce_sum = tf.reduce_sum(fp32_mask) +# normalize_factor = tf.cast(float32_mask_size / (float32_mask_reduce_sum + 1e-8), x.dtype) +# ret = ret * normalize_factor +# return ret + + +def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False): + _, _, height, width = x.shape + total_size = width * height + clipped_block_size = min(block_size, min(width, height)) + # seed_drop_rate, the gamma parameter + seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( + (width - block_size + 1) * + (height - block_size + 1)) + + # Forces the block to be inside the feature map. + w_i, h_i = torch.meshgrid(torch.arange(width), torch.arange(height)) + valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \ + ((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2)) + valid_block = torch.reshape(valid_block, (1, 1, height, width)) + valid_block = valid_block.to(x.dtype) + + uniform_noise = torch.rand_like(x) + block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).to(x.dtype) + block_mask = -F.max_pool2d( + -block_mask, + kernel_size=clipped_block_size, # block_size, + stride=1, + padding=clipped_block_size // 2) + + if drop_with_noise: + normal_noise = torch.randn_like(x) + x = x * block_mask + normal_noise * (1 - block_mask) + else: + normalize_scale = block_mask.numel() / (torch.sum(block_mask, dtype=torch.float32) + 1e-7) + x = x * block_mask * normalize_scale + return x + + +class DropBlock2d(nn.Module): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """ + def __init__(self, + drop_prob=0.1, + block_size=7, + gamma_scale=1.0, + with_noise=False): + super(DropBlock2d, self).__init__() + self.drop_prob = drop_prob + self.gamma_scale = gamma_scale + self.block_size = block_size + self.with_noise = with_noise + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + return drop_block_2d(x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise) + + +def drop_path(x, drop_prob=0.): + """Drop paths (Stochastic Depth) per sample (when applied in residual blocks).""" + keep_prob = 1 - drop_prob + random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.ModuleDict): + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + if not self.training or not self.drop_prob: + return x + return drop_path(x, self.drop_prob) From ef457555d3fdf53ba9ec7765bf9477c5c86b84e6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 29 Jan 2020 14:34:45 -0800 Subject: [PATCH 09/23] BlockDrop working on GPU --- timm/models/nn_ops.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/timm/models/nn_ops.py b/timm/models/nn_ops.py index 37286611..9b931efb 100644 --- a/timm/models/nn_ops.py +++ b/timm/models/nn_ops.py @@ -83,14 +83,13 @@ def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noi (height - block_size + 1)) # Forces the block to be inside the feature map. - w_i, h_i = torch.meshgrid(torch.arange(width), torch.arange(height)) + w_i, h_i = torch.meshgrid(torch.arange(width).to(x.device), torch.arange(height).to(x.device)) valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \ ((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2)) - valid_block = torch.reshape(valid_block, (1, 1, height, width)) - valid_block = valid_block.to(x.dtype) + valid_block = torch.reshape(valid_block, (1, 1, height, width)).float() - uniform_noise = torch.rand_like(x) - block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).to(x.dtype) + uniform_noise = torch.rand_like(x, dtype=torch.float32) + block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) block_mask = -F.max_pool2d( -block_mask, kernel_size=clipped_block_size, # block_size, From 355aa152d5c89f210ae0771f800598409807dacf Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 29 Jan 2020 14:51:34 -0800 Subject: [PATCH 10/23] Just leave it float for now, will look at fp16 later. Remove unused reference code. --- timm/models/nn_ops.py | 72 ++----------------------------------------- 1 file changed, 2 insertions(+), 70 deletions(-) diff --git a/timm/models/nn_ops.py b/timm/models/nn_ops.py index 9b931efb..30b98427 100644 --- a/timm/models/nn_ops.py +++ b/timm/models/nn_ops.py @@ -4,74 +4,6 @@ import torch.nn.functional as F import numpy as np import math -## Assembled CNN Tensorflow Impl -# -# def _bernoulli(shape, mean, seed=None, dtype=tf.float32): -# return tf.nn.relu(tf.sign(mean - tf.random_uniform(shape, minval=0, maxval=1, dtype=dtype, seed=seed))) -# -# -# def dropblock(x, keep_prob, block_size, gamma_scale=1.0, seed=None, name=None, -# data_format='channels_last', is_training=True): # pylint: disable=invalid-name -# """ -# Dropblock layer. For more details, refer to https://arxiv.org/abs/1810.12890 -# :param x: A floating point tensor. -# :param keep_prob: A scalar Tensor with the same type as x. The probability that each element is kept. -# :param block_size: The block size to drop -# :param gamma_scale: The multiplier to gamma. -# :param seed: Python integer. Used to create random seeds. -# :param name: A name for this operation (optional) -# :param data_format: 'channels_last' or 'channels_first' -# :param is_training: If False, do nothing. -# :return: A Tensor of the same shape of x. -# """ -# if not is_training: -# return x -# -# # Early return if nothing needs to be dropped. -# if (isinstance(keep_prob, float) and keep_prob == 1) or gamma_scale == 0: -# return x -# -# with tf.name_scope(name, "dropblock", [x]) as name: -# if not x.dtype.is_floating: -# raise ValueError("x has to be a floating point tensor since it's going to" -# " be scaled. Got a %s tensor instead." % x.dtype) -# if isinstance(keep_prob, float) and not 0 < keep_prob <= 1: -# raise ValueError("keep_prob must be a scalar tensor or a float in the " -# "range (0, 1], got %g" % keep_prob) -# -# br = (block_size - 1) // 2 -# tl = (block_size - 1) - br -# if data_format == 'channels_last': -# _, h, w, c = x.shape.as_list() -# sampling_mask_shape = tf.stack([1, h - block_size + 1, w - block_size + 1, c]) -# pad_shape = [[0, 0], [tl, br], [tl, br], [0, 0]] -# elif data_format == 'channels_first': -# _, c, h, w = x.shape.as_list() -# sampling_mask_shape = tf.stack([1, c, h - block_size + 1, w - block_size + 1]) -# pad_shape = [[0, 0], [0, 0], [tl, br], [tl, br]] -# else: -# raise NotImplementedError -# -# gamma = (1. - keep_prob) * (w * h) / (block_size ** 2) / ((w - block_size + 1) * (h - block_size + 1)) -# gamma = gamma_scale * gamma -# mask = _bernoulli(sampling_mask_shape, gamma, seed, tf.float32) -# mask = tf.pad(mask, pad_shape) -# -# xdtype_mask = tf.cast(mask, x.dtype) -# xdtype_mask = tf.layers.max_pooling2d( -# inputs=xdtype_mask, pool_size=block_size, -# strides=1, padding='SAME', -# data_format=data_format) -# -# xdtype_mask = 1 - xdtype_mask -# fp32_mask = tf.cast(xdtype_mask, tf.float32) -# ret = tf.multiply(x, xdtype_mask) -# float32_mask_size = tf.cast(tf.size(fp32_mask), tf.float32) -# float32_mask_reduce_sum = tf.reduce_sum(fp32_mask) -# normalize_factor = tf.cast(float32_mask_size / (float32_mask_reduce_sum + 1e-8), x.dtype) -# ret = ret * normalize_factor -# return ret - def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False): _, _, height, width = x.shape @@ -89,7 +21,7 @@ def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noi valid_block = torch.reshape(valid_block, (1, 1, height, width)).float() uniform_noise = torch.rand_like(x, dtype=torch.float32) - block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) + block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float() block_mask = -F.max_pool2d( -block_mask, kernel_size=clipped_block_size, # block_size, @@ -100,7 +32,7 @@ def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noi normal_noise = torch.randn_like(x) x = x * block_mask + normal_noise * (1 - block_mask) else: - normalize_scale = block_mask.numel() / (torch.sum(block_mask, dtype=torch.float32) + 1e-7) + normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7) x = x * block_mask * normalize_scale return x From a9d2424fd1680590146bcd4eed912cc84bbe6a5e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 30 Jan 2020 16:51:49 -0800 Subject: [PATCH 11/23] Add separate zero_init_last_bn function to support more block variety without a mess --- timm/models/res2net.py | 3 ++ timm/models/resnet.py | 86 +++++++++++++++++++++++++++--------------- timm/models/sknet.py | 6 +++ 3 files changed, 64 insertions(+), 31 deletions(-) diff --git a/timm/models/res2net.py b/timm/models/res2net.py index c83aba62..bcb7eaaf 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -87,6 +87,9 @@ class Bottle2neck(nn.Module): self.relu = act_layer(inplace=True) self.downsample = downsample + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + def forward(self, x): residual = x diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 3e0ce23e..d97a6aad 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -156,26 +156,38 @@ class BasicBlock(nn.Module): self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn2.weight) def forward(self, x): residual = x - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - out = self.conv2(out) - out = self.bn2(out) + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) if self.se is not None: - out = self.se(out) + x = self.se(x) - if self.downsample is not None: - residual = self.downsample(x) + if self.drop_path is not None: + x = self.drop_path(x) - out += residual - out = self.act2(out) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act2(x) - return out + return x class Bottleneck(nn.Module): @@ -207,31 +219,44 @@ class Bottleneck(nn.Module): self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) def forward(self, x): residual = x - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) - out = self.conv2(out) - out = self.bn2(out) - out = self.act2(out) + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act2(x) - out = self.conv3(out) - out = self.bn3(out) + x = self.conv3(x) + x = self.bn3(x) + if self.drop_block is not None: + x = self.drop_block(x) if self.se is not None: - out = self.se(out) + x = self.se(x) - if self.downsample is not None: - residual = self.downsample(x) + if self.drop_path is not None: + x = self.drop_path(x) - out += residual - out = self.act3(out) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act3(x) - return out + return x class ResNet(nn.Module): @@ -367,17 +392,16 @@ class ResNet(nn.Module): self.num_features = 512 * block.expansion self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) - last_bn_name = 'bn3' if 'Bottle' in block.__name__ else 'bn2' for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): - if zero_init_last_bn and 'layer' in n and last_bn_name in n: - # Initialize weight/gamma of last BN in each residual block to zero - nn.init.constant_(m.weight, 0.) - else: - nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.weight, 1.) nn.init.constant_(m.bias, 0.) + if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1, use_se=False, avg_down=False, down_kernel_size=1, **kwargs): diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 41e19075..7cf4fbe6 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -63,6 +63,9 @@ class SelectiveKernelBasic(nn.Module): self.drop_block = drop_block self.drop_path = drop_path + def zero_init_last_bn(self): + nn.init.zeros_(self.conv2.bn.weight) + def forward(self, x): residual = x x = self.conv1(x) @@ -109,6 +112,9 @@ class SelectiveKernelBottleneck(nn.Module): self.drop_block = drop_block self.drop_path = drop_path + def zero_init_last_bn(self): + nn.init.zeros_(self.conv3.bn.weight) + def forward(self, x): residual = x x = self.conv1(x) From 7d07ebb66075db35560befc921b8f3098e7f79f0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 1 Feb 2020 23:28:48 -0800 Subject: [PATCH 12/23] Adding some configs to sknet, incl ResNet50 variants from 'Compounding ... Assembled Techniques' paper and original SKNet50 --- timm/models/sknet.py | 84 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 70 insertions(+), 14 deletions(-) diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 7cf4fbe6..0c387e39 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -22,7 +22,10 @@ def _cfg(url='', **kwargs): default_cfgs = { 'skresnet18': _cfg(url=''), - 'skresnet26d': _cfg() + 'skresnet26d': _cfg(), + 'skresnet50': _cfg(), + 'skresnet50d': _cfg(), + 'skresnext50_32x4d': _cfg(), } @@ -131,6 +134,41 @@ class SelectiveKernelBottleneck(nn.Module): return x +@register_model +def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['skresnet18'] + sk_kwargs = dict( + min_attn_channels=16, + ) + model = ResNet( + SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['skresnet18'] + sk_kwargs = dict( + min_attn_channels=16, + split_input=True + ) + model = ResNet( + SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-26 model. @@ -150,15 +188,17 @@ def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNet-18 model. +def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Select Kernel ResNet-50 model. + Based on config in "Compounding the Performance Improvements of Assembled Techniques in a + Convolutional Neural Network" """ - default_cfg = default_cfgs['skresnet18'] sk_kwargs = dict( - min_attn_channels=16, + attn_reduction=2, ) + default_cfg = default_cfgs['skresnet50'] model = ResNet( - SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, + SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), **kwargs) model.default_cfg = default_cfg if pretrained: @@ -167,18 +207,34 @@ def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNet-18 model. +def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Select Kernel ResNet-50-D model. + Based on config in "Compounding the Performance Improvements of Assembled Techniques in a + Convolutional Neural Network" """ - default_cfg = default_cfgs['skresnet18'] sk_kwargs = dict( - min_attn_channels=16, - split_input=True + attn_reduction=2, ) + default_cfg = default_cfgs['skresnet50d'] model = ResNet( - SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, - block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to + the SKNet50 model in the Select Kernel Paper + """ + default_cfg = default_cfgs['skresnext50_32x4d'] + model = ResNet( + SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, + num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) - return model \ No newline at end of file + return model From 13e8da2b46d8b48fa4bdc76dd89cd7aaf3f7d615 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 7 Feb 2020 22:42:04 -0800 Subject: [PATCH 13/23] SelectKernel split_input works best when input channels split like grouped conv, but output is full width. Disable zero_init for SK nets, seems a bad combo. --- timm/models/conv2d_layers.py | 22 +++++++--------------- timm/models/sknet.py | 12 +++++++----- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py index 5b1a44e8..feaf653c 100644 --- a/timm/models/conv2d_layers.py +++ b/timm/models/conv2d_layers.py @@ -311,15 +311,13 @@ class SelectiveKernelConv(nn.Module): kernel_size = [3] * len(kernel_size) else: dilation = [dilation] * len(kernel_size) - num_paths = len(kernel_size) - self.num_paths = num_paths - self.split_input = split_input + self.num_paths = len(kernel_size) self.in_channels = in_channels self.out_channels = out_channels - if split_input: - assert in_channels % num_paths == 0 and out_channels % num_paths == 0 - in_channels = in_channels // num_paths - out_channels = out_channels // num_paths + self.split_input = split_input + if self.split_input: + assert in_channels % self.num_paths == 0 + in_channels = in_channels // self.num_paths groups = min(out_channels, groups) conv_kwargs = dict( @@ -329,7 +327,7 @@ class SelectiveKernelConv(nn.Module): for k, d in zip(kernel_size, dilation)]) attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) - self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels) + self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) self.drop_block = drop_block def forward(self, x): @@ -338,16 +336,10 @@ class SelectiveKernelConv(nn.Module): x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] else: x_paths = [op(x) for op in self.paths] - x = torch.stack(x_paths, dim=1) x_attn = self.attn(x) x = x * x_attn - - if self.split_input: - B, N, C, H, W = x.shape - x = x.reshape(B, N * C, H, W) - else: - x = torch.sum(x, dim=1) + x = torch.sum(x, dim=1) return x diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 0c387e39..4b02d501 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -158,11 +158,12 @@ def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['skresnet18'] sk_kwargs = dict( min_attn_channels=16, + attn_reduction=8, split_input=True ) model = ResNet( SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, - block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -179,7 +180,7 @@ def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): ) model = ResNet( SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), + num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False **kwargs) model.default_cfg = default_cfg if pretrained: @@ -199,7 +200,7 @@ def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['skresnet50'] model = ResNet( SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, - block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -218,7 +219,8 @@ def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['skresnet50d'] model = ResNet( SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -233,7 +235,7 @@ def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['skresnext50_32x4d'] model = ResNet( SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) + num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) From cade8291052c040e9b62be884f7a446ef6c82dd7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 9 Feb 2020 11:04:48 -0800 Subject: [PATCH 14/23] Add EfficientNet-ES to sotabench --- sotabench.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sotabench.py b/sotabench.py index bd5b0b81..66b7d323 100644 --- a/sotabench.py +++ b/sotabench.py @@ -54,6 +54,8 @@ model_list = [ model_desc='Trained from scratch in PyTorch w/ RandAugment'), _entry('efficientnet_b3a', 'EfficientNet-B3 (320x320, 1.0 crop)', '1905.11946', model_desc='Trained from scratch in PyTorch w/ RandAugment'), + _entry('efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946', + model_desc='Trained from scratch in PyTorch w/ RandAugment'), _entry('fbnetc_100', 'FBNet-C', '1812.03443', model_desc='Trained in PyTorch with RMSProp, exponential LR decay'), _entry('gluon_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from GluonCV Model Zoo'), From d0eb59ef467d223aaee8c5e421c2cf9b1a2b8929 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 9 Feb 2020 11:32:05 -0800 Subject: [PATCH 15/23] Remove unused default_init for EfficientNets, experimenting with fanout calc for #84 --- timm/models/efficientnet_builder.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index db6f54f9..ca2060c4 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -358,15 +358,24 @@ class EfficientNetBuilder: return stages -def _init_weight_goog(m, n=''): +def _init_weight_goog(m, n='', fix_group_fanout=False): """ Weight initialization as per Tensorflow official implementations. + Args: + m (nn.Module): module to init + n (str): module name + fix_group_fanout (bool): enable correct fanout calculation w/ group convs + + FIXME change fix_group_fanout to default to True if experiments show better training results + Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py """ if isinstance(m, CondConv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups init_weight_fn = get_condconv_initializer( lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) init_weight_fn(m.weight) @@ -374,6 +383,8 @@ def _init_weight_goog(m, n=''): m.bias.data.zero_() elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() @@ -390,21 +401,6 @@ def _init_weight_goog(m, n=''): m.bias.data.zero_() -def _init_weight_default(m, n=''): - """ Basic ResNet (Kaiming) style weight init""" - if isinstance(m, CondConv2d): - init_fn = get_condconv_initializer(partial( - nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) - init_fn(m.weight) - elif isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') - - def efficientnet_init_weights(model: nn.Module, init_fn=None): init_fn = init_fn or _init_weight_goog for n, m in model.named_modules(): From 7011cd0902009baaa3ceb4754fe7f379361b8a6a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 9 Feb 2020 12:41:18 -0800 Subject: [PATCH 16/23] A little bit of ECA cleanup --- timm/models/EcaModule.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/timm/models/EcaModule.py b/timm/models/EcaModule.py index b91b5801..fab205cb 100644 --- a/timm/models/EcaModule.py +++ b/timm/models/EcaModule.py @@ -1,14 +1,16 @@ -''' +""" ECA module from ECAnet -original paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks + +paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks https://arxiv.org/abs/1910.03151 -https://github.com/BangguWu/ECANet -original ECA model borrowed from original github -modified circular ECA implementation and -adoptation for use in pytorch image models package +Original ECA model borrowed from https://github.com/BangguWu/ECANet + +Modified circular ECA implementation and adaption for use in timm package by Chris Ha https://github.com/VRandme +Original License: + MIT License Copyright (c) 2019 BangguWu, Qilong Wang @@ -30,13 +32,14 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -''' +""" import math from torch import nn import torch.nn.functional as F + class EcaModule(nn.Module): - """Constructs a ECA module. + """Constructs an ECA module. Args: channel: Number of channels of the input feature map for use in adaptive kernel sizes @@ -59,9 +62,9 @@ class EcaModule(nn.Module): self.sigmoid = nn.Sigmoid() def forward(self, x): - # feature descriptor on the global spatial information + # Feature descriptor on the global spatial information y = self.avg_pool(x) - # reshape for convolution + # Reshape for convolution y = y.view(x.shape[0], 1, -1) # Two different branches of ECA module y = self.conv(y) @@ -69,10 +72,12 @@ class EcaModule(nn.Module): y = self.sigmoid(y.view(x.shape[0], -1, 1, 1)) return x * y.expand_as(x) + class CecaModule(nn.Module): """Constructs a circular ECA module. - the primary difference is that the conv uses a circular padding rather than zero padding. - This is because unlike images, the channels themselves do not have inherent ordering nor + + ECA module where the conv uses circular padding rather than zero padding. + Unlike the spatial dimension, the channels do not have inherent ordering nor locality. Although this module in essence, applies such an assumption, it is unnecessary to limit the channels on either "edge" from being circularly adapted to each other. This will fundamentally increase connectivity and possibly increase performance metrics @@ -97,7 +102,7 @@ class CecaModule(nn.Module): k_size = t if t % 2 else t + 1 self.avg_pool = nn.AdaptiveAvgPool2d(1) - #pytorch circular padding mode is bugged as of pytorch 1.4 + #pytorch circular padding mode is buggy as of pytorch 1.4 #see https://github.com/pytorch/pytorch/pull/17240 #implement manual circular padding @@ -106,10 +111,10 @@ class CecaModule(nn.Module): self.sigmoid = nn.Sigmoid() def forward(self, x): - # feature descriptor on the global spatial information + # Feature descriptor on the global spatial information y = self.avg_pool(x) - #manually implement circular padding, F.pad does not seemed to be bugged + # Manually implement circular padding, F.pad does not seemed to be bugged y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular') # Two different branches of ECA module From 4defbbbaa89049f1805951cda9db15f832c32695 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 9 Feb 2020 12:44:26 -0800 Subject: [PATCH 17/23] Fix module name mistake, start layers sub-package --- timm/models/layers/__init__.py | 1 + timm/models/{EcaModule.py => layers/eca.py} | 0 timm/models/resnet.py | 4 ++-- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 timm/models/layers/__init__.py rename timm/models/{EcaModule.py => layers/eca.py} (100%) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py new file mode 100644 index 00000000..325516e9 --- /dev/null +++ b/timm/models/layers/__init__.py @@ -0,0 +1 @@ +from .eca import EcaModule, CecaModule diff --git a/timm/models/EcaModule.py b/timm/models/layers/eca.py similarity index 100% rename from timm/models/EcaModule.py rename to timm/models/layers/eca.py diff --git a/timm/models/resnet.py b/timm/models/resnet.py index da755373..893350ef 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -14,7 +14,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d -from .EcaModule import EcaModule +from .layers import EcaModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -212,7 +212,7 @@ class Bottleneck(nn.Module): self.bn3 = norm_layer(outplanes) self.se = SEModule(outplanes, planes // 4) if use_se else None - self.eca = Eca_Module(outplanes) if use_eca else None + self.eca = EcaModule(outplanes) if use_eca else None self.act3 = act_layer(inplace=True) self.downsample = downsample From 13746a33fcdd787d2d8fa5c6b729362ebfdddc7a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 9 Feb 2020 13:13:08 -0800 Subject: [PATCH 18/23] Big move, layer modules and fn to timm/models/layers --- timm/models/__init__.py | 4 ++-- timm/models/densenet.py | 2 +- timm/models/dla.py | 2 +- timm/models/dpn.py | 2 +- timm/models/efficientnet.py | 3 +-- timm/models/efficientnet_blocks.py | 4 ++-- timm/models/efficientnet_builder.py | 2 +- timm/models/gluon_xception.py | 2 +- timm/models/hrnet.py | 2 +- timm/models/inception_resnet_v2.py | 2 +- timm/models/inception_v4.py | 2 +- timm/models/layers/__init__.py | 7 +++++++ timm/models/{ => layers}/activations.py | 0 timm/models/{ => layers}/adaptive_avgmax_pool.py | 0 timm/models/{ => layers}/conv2d_layers.py | 0 timm/models/{ => layers}/median_pool.py | 0 timm/models/{ => layers}/nn_ops.py | 0 timm/models/{ => layers}/split_batchnorm.py | 0 timm/models/{ => layers}/test_time_pool.py | 0 timm/models/mobilenetv3.py | 5 ++--- timm/models/nasnet.py | 2 +- timm/models/pnasnet.py | 2 +- timm/models/res2net.py | 2 +- timm/models/resnet.py | 4 +--- timm/models/selecsls.py | 2 +- timm/models/senet.py | 2 +- timm/models/sknet.py | 8 ++++---- timm/models/xception.py | 2 +- 28 files changed, 33 insertions(+), 30 deletions(-) rename timm/models/{ => layers}/activations.py (100%) rename timm/models/{ => layers}/adaptive_avgmax_pool.py (100%) rename timm/models/{ => layers}/conv2d_layers.py (100%) rename timm/models/{ => layers}/median_pool.py (100%) rename timm/models/{ => layers}/nn_ops.py (100%) rename timm/models/{ => layers}/split_batchnorm.py (100%) rename timm/models/{ => layers}/test_time_pool.py (100%) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 69e09085..cc4d470e 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -21,5 +21,5 @@ from .sknet import * from .registry import * from .factory import create_model from .helpers import load_checkpoint, resume_checkpoint -from .test_time_pool import TestTimePoolHead, apply_test_time_pool -from .split_batchnorm import convert_splitbn_model +from .layers import TestTimePoolHead, apply_test_time_pool +from .layers import convert_splitbn_model diff --git a/timm/models/densenet.py b/timm/models/densenet.py index d1ac5857..4235c0f7 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD import re diff --git a/timm/models/dla.py b/timm/models/dla.py index cd560f44..a9e81d16 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 7f46e8e0..fd58e516 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -16,7 +16,7 @@ from collections import OrderedDict from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 8d07a2ca..7261fe10 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -27,8 +27,7 @@ from .efficientnet_builder import * from .feature_hooks import FeatureHooks from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d -from .conv2d_layers import select_conv2d +from .layers import SelectAdaptivePool2d, select_conv2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index 13ab051a..78d451be 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -4,8 +4,8 @@ from functools import partial import torch import torch.nn as nn import torch.nn.functional as F -from .activations import sigmoid -from .conv2d_layers import * +from .layers.activations import sigmoid +from .layers.conv2d_layers import * # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index ca2060c4..b159eefe 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -5,7 +5,7 @@ from collections.__init__ import OrderedDict from copy import deepcopy import torch.nn as nn -from .activations import sigmoid, HardSwish, Swish +from .layers.activations import sigmoid, HardSwish, Swish from .efficientnet_blocks import * diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index 5a35d226..2fc8e699 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -13,7 +13,7 @@ from collections import OrderedDict from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['Xception65', 'Xception71'] diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 99a2bd91..16df5bc1 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -25,7 +25,7 @@ import torch.nn.functional as F from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD _BN_MOMENTUM = 0.1 diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index 285863f5..13ad0e9d 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD __all__ = ['InceptionResnetV2'] diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 8c3dee86..16080554 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD __all__ = ['InceptionV4'] diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 325516e9..8e9fcae2 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -1 +1,8 @@ +from .conv2d_layers import select_conv2d, MixedConv2d, CondConv2d, ConvBnAct, SelectiveKernelConv from .eca import EcaModule, CecaModule +from .activations import * +from .adaptive_avgmax_pool import \ + adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d +from .nn_ops import DropBlock2d, DropPath +from .test_time_pool import TestTimePoolHead, apply_test_time_pool +from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model diff --git a/timm/models/activations.py b/timm/models/layers/activations.py similarity index 100% rename from timm/models/activations.py rename to timm/models/layers/activations.py diff --git a/timm/models/adaptive_avgmax_pool.py b/timm/models/layers/adaptive_avgmax_pool.py similarity index 100% rename from timm/models/adaptive_avgmax_pool.py rename to timm/models/layers/adaptive_avgmax_pool.py diff --git a/timm/models/conv2d_layers.py b/timm/models/layers/conv2d_layers.py similarity index 100% rename from timm/models/conv2d_layers.py rename to timm/models/layers/conv2d_layers.py diff --git a/timm/models/median_pool.py b/timm/models/layers/median_pool.py similarity index 100% rename from timm/models/median_pool.py rename to timm/models/layers/median_pool.py diff --git a/timm/models/nn_ops.py b/timm/models/layers/nn_ops.py similarity index 100% rename from timm/models/nn_ops.py rename to timm/models/layers/nn_ops.py diff --git a/timm/models/split_batchnorm.py b/timm/models/layers/split_batchnorm.py similarity index 100% rename from timm/models/split_batchnorm.py rename to timm/models/layers/split_batchnorm.py diff --git a/timm/models/test_time_pool.py b/timm/models/layers/test_time_pool.py similarity index 100% rename from timm/models/test_time_pool.py rename to timm/models/layers/test_time_pool.py diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index a6b67532..76f6363c 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -11,11 +11,10 @@ import torch.nn as nn import torch.nn.functional as F from .efficientnet_builder import * -from .activations import HardSwish, hard_sigmoid from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d -from .conv2d_layers import select_conv2d +from .layers import SelectAdaptivePool2d, select_conv2d +from .layers.activations import HardSwish, hard_sigmoid from .feature_hooks import FeatureHooks from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 009c62d3..8847b1de 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d __all__ = ['NASNetALarge'] diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 396e6157..dc9b3e20 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -14,7 +14,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d __all__ = ['PNASNet5Large'] diff --git a/timm/models/res2net.py b/timm/models/res2net.py index bcb7eaaf..b8d31b3e 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -11,7 +11,7 @@ import torch.nn.functional as F from .resnet import ResNet, SEModule from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = [] diff --git a/timm/models/resnet.py b/timm/models/resnet.py index a1c593ae..528d5790 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -13,9 +13,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d -from .layers import EcaModule -from .nn_ops import DropBlock2d, DropPath +from .layers import EcaModule, SelectAdaptivePool2d, DropBlock2d, DropPath from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 17796700..2f369e99 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this diff --git a/timm/models/senet.py b/timm/models/senet.py index 90ef5ae1..efbf4657 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['SENet'] diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 4b02d501..032b7d0b 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -2,10 +2,10 @@ import math from torch import nn as nn -from timm.models.registry import register_model -from timm.models.helpers import load_pretrained -from timm.models.conv2d_layers import SelectiveKernelConv, ConvBnAct -from timm.models.resnet import ResNet, SEModule +from .registry import register_model +from .helpers import load_pretrained +from .layers import SelectiveKernelConv, ConvBnAct +from .resnet import ResNet, SEModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD diff --git a/timm/models/xception.py b/timm/models/xception.py index 2dc334fa..cb98bbc9 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -29,7 +29,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .layers import SelectAdaptivePool2d __all__ = ['Xception'] From a99ec4e7d16b45255991f513fdbd2be76abfe598 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 9 Feb 2020 14:46:28 -0800 Subject: [PATCH 19/23] A bunch more layer reorg, splitting many layers into own files. Improve torchscript compatibility. --- timm/models/efficientnet.py | 3 +- timm/models/efficientnet_blocks.py | 35 ++- timm/models/efficientnet_builder.py | 3 +- timm/models/layers/__init__.py | 8 +- timm/models/layers/activations.py | 35 ++- timm/models/layers/cond_conv2d.py | 118 +++++++ timm/models/layers/conv2d_layers.py | 361 ---------------------- timm/models/layers/conv2d_same.py | 79 +++++ timm/models/layers/conv_bn_act.py | 32 ++ timm/models/layers/conv_helpers.py | 27 ++ timm/models/layers/{nn_ops.py => drop.py} | 0 timm/models/layers/mixed_conv2d.py | 49 +++ timm/models/layers/select_conv2d.py | 30 ++ timm/models/layers/selective_kernel.py | 88 ++++++ timm/models/layers/test_time_pool.py | 5 + timm/models/mobilenetv3.py | 2 - 16 files changed, 479 insertions(+), 396 deletions(-) create mode 100644 timm/models/layers/cond_conv2d.py delete mode 100644 timm/models/layers/conv2d_layers.py create mode 100644 timm/models/layers/conv2d_same.py create mode 100644 timm/models/layers/conv_bn_act.py create mode 100644 timm/models/layers/conv_helpers.py rename timm/models/layers/{nn_ops.py => drop.py} (100%) create mode 100644 timm/models/layers/mixed_conv2d.py create mode 100644 timm/models/layers/select_conv2d.py create mode 100644 timm/models/layers/selective_kernel.py diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 7261fe10..c5dcacd3 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -27,7 +27,8 @@ from .efficientnet_builder import * from .feature_hooks import FeatureHooks from .registry import register_model from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d, select_conv2d +from .layers import SelectAdaptivePool2d +from timm.models.layers import select_conv2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index 78d451be..a231fa31 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -1,11 +1,8 @@ - -from functools import partial - import torch import torch.nn as nn -import torch.nn.functional as F +from torch.nn import functional as F from .layers.activations import sigmoid -from .layers.conv2d_layers import * +from .layers import select_conv2d # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per @@ -72,7 +69,7 @@ def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): return make_divisible(channels, divisor, channel_min) -def drop_connect(inputs, training=False, drop_connect_rate=0.): +def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.): """Apply drop connect.""" if not training: return inputs @@ -160,7 +157,7 @@ class DepthwiseSeparableConv(nn.Module): norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): super(DepthwiseSeparableConv, self).__init__() norm_kwargs = norm_kwargs or {} - self.has_se = se_ratio is not None and se_ratio > 0. + has_se = se_ratio is not None and se_ratio > 0. self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip self.has_pw_act = pw_act # activation after point-wise conv self.drop_connect_rate = drop_connect_rate @@ -171,9 +168,11 @@ class DepthwiseSeparableConv(nn.Module): self.act1 = act_layer(inplace=True) # Squeeze-and-excitation - if self.has_se: + if has_se: se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) self.bn2 = norm_layer(out_chs, **norm_kwargs) @@ -193,7 +192,7 @@ class DepthwiseSeparableConv(nn.Module): x = self.bn1(x) x = self.act1(x) - if self.has_se: + if self.se is not None: x = self.se(x) x = self.conv_pw(x) @@ -219,7 +218,7 @@ class InvertedResidual(nn.Module): norm_kwargs = norm_kwargs or {} conv_kwargs = conv_kwargs or {} mid_chs = make_divisible(in_chs * exp_ratio) - self.has_se = se_ratio is not None and se_ratio > 0. + has_se = se_ratio is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_connect_rate = drop_connect_rate @@ -236,9 +235,11 @@ class InvertedResidual(nn.Module): self.act2 = act_layer(inplace=True) # Squeeze-and-excitation - if self.has_se: + if has_se: se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None # Point-wise linear projection self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) @@ -269,7 +270,7 @@ class InvertedResidual(nn.Module): x = self.act2(x) # Squeeze-and-excitation - if self.has_se: + if self.se is not None: x = self.se(x) # Point-wise linear projection @@ -323,7 +324,7 @@ class CondConvResidual(InvertedResidual): x = self.act2(x) # Squeeze-and-excitation - if self.has_se: + if self.se is not None: x = self.se(x) # Point-wise linear projection @@ -350,7 +351,7 @@ class EdgeResidual(nn.Module): mid_chs = make_divisible(fake_in_chs * exp_ratio) else: mid_chs = make_divisible(in_chs * exp_ratio) - self.has_se = se_ratio is not None and se_ratio > 0. + has_se = se_ratio is not None and se_ratio > 0. self.has_residual = (in_chs == out_chs and stride == 1) and not noskip self.drop_connect_rate = drop_connect_rate @@ -360,9 +361,11 @@ class EdgeResidual(nn.Module): self.act1 = act_layer(inplace=True) # Squeeze-and-excitation - if self.has_se: + if has_se: se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = None # Point-wise linear projection self.conv_pwl = select_conv2d( @@ -389,7 +392,7 @@ class EdgeResidual(nn.Module): x = self.act1(x) # Squeeze-and-excitation - if self.has_se: + if self.se is not None: x = self.se(x) # Point-wise linear projection diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index b159eefe..954420fb 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -5,7 +5,8 @@ from collections.__init__ import OrderedDict from copy import deepcopy import torch.nn as nn -from .layers.activations import sigmoid, HardSwish, Swish +from .layers import CondConv2d, get_condconv_initializer +from .layers.activations import HardSwish, Swish from .efficientnet_blocks import * diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 8e9fcae2..79aa9ac2 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -1,8 +1,12 @@ -from .conv2d_layers import select_conv2d, MixedConv2d, CondConv2d, ConvBnAct, SelectiveKernelConv +from .conv_bn_act import ConvBnAct +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d, get_condconv_initializer +from .select_conv2d import select_conv2d +from .selective_kernel import SelectiveKernelConv from .eca import EcaModule, CecaModule from .activations import * from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d -from .nn_ops import DropBlock2d, DropPath +from .drop import DropBlock2d, DropPath from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model diff --git a/timm/models/layers/activations.py b/timm/models/layers/activations.py index aafa290c..165b7951 100644 --- a/timm/models/layers/activations.py +++ b/timm/models/layers/activations.py @@ -1,9 +1,18 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Hacked together by Ross Wightman +""" + + import torch from torch import nn as nn from torch.nn import functional as F -_USE_MEM_EFFICIENT_ISH = True +_USE_MEM_EFFICIENT_ISH = False if _USE_MEM_EFFICIENT_ISH: # This version reduces memory overhead of Swish during training by # recomputing torch.sigmoid(x) in backward instead of saving it. @@ -66,20 +75,20 @@ if _USE_MEM_EFFICIENT_ISH: return MishJitAutoFn.apply(x) else: - def swish(x, inplace=False): + def swish(x, inplace: bool = False): """Swish - Described in: https://arxiv.org/abs/1710.05941 """ return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) - def mish(x, _inplace=False): + def mish(x, _inplace: bool = False): """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 """ return x.mul(F.softplus(x).tanh()) class Swish(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(Swish, self).__init__() self.inplace = inplace @@ -88,7 +97,7 @@ class Swish(nn.Module): class Mish(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(Mish, self).__init__() self.inplace = inplace @@ -96,13 +105,13 @@ class Mish(nn.Module): return mish(x, self.inplace) -def sigmoid(x, inplace=False): +def sigmoid(x, inplace: bool = False): return x.sigmoid_() if inplace else x.sigmoid() # PyTorch has this, but not with a consistent inplace argmument interface class Sigmoid(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(Sigmoid, self).__init__() self.inplace = inplace @@ -110,13 +119,13 @@ class Sigmoid(nn.Module): return x.sigmoid_() if self.inplace else x.sigmoid() -def tanh(x, inplace=False): +def tanh(x, inplace: bool = False): return x.tanh_() if inplace else x.tanh() # PyTorch has this, but not with a consistent inplace argmument interface class Tanh(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(Tanh, self).__init__() self.inplace = inplace @@ -124,13 +133,13 @@ class Tanh(nn.Module): return x.tanh_() if self.inplace else x.tanh() -def hard_swish(x, inplace=False): +def hard_swish(x, inplace: bool = False): inner = F.relu6(x + 3.).div_(6.) return x.mul_(inner) if inplace else x.mul(inner) class HardSwish(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(HardSwish, self).__init__() self.inplace = inplace @@ -138,7 +147,7 @@ class HardSwish(nn.Module): return hard_swish(x, self.inplace) -def hard_sigmoid(x, inplace=False): +def hard_sigmoid(x, inplace: bool = False): if inplace: return x.add_(3.).clamp_(0., 6.).div_(6.) else: @@ -146,7 +155,7 @@ def hard_sigmoid(x, inplace=False): class HardSigmoid(nn.Module): - def __init__(self, inplace=False): + def __init__(self, inplace: bool = False): super(HardSigmoid, self).__init__() self.inplace = inplace diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py new file mode 100644 index 00000000..d6cba889 --- /dev/null +++ b/timm/models/layers/cond_conv2d.py @@ -0,0 +1,118 @@ +""" Conditional Convolution + +Hacked together by Ross Wightman +""" + +import math +from functools import partial +import numpy as np +import torch +from torch import nn as nn +from torch.nn import functional as F + +from .conv2d_same import get_padding_value, conv2d_same +from .conv_helpers import tup_pair + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditional Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = tup_pair(kernel_size) + self.stride = tup_pair(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = tup_pair(padding_val) + self.dilation = tup_pair(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out diff --git a/timm/models/layers/conv2d_layers.py b/timm/models/layers/conv2d_layers.py deleted file mode 100644 index feaf653c..00000000 --- a/timm/models/layers/conv2d_layers.py +++ /dev/null @@ -1,361 +0,0 @@ -from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch._six import container_abcs -from itertools import repeat -from functools import partial -import numpy as np -import math - - -# Tuple helpers ripped from PyTorch -def _ntuple(n): - def parse(x): - if isinstance(x, container_abcs.Iterable): - return x - return tuple(repeat(x, n)) - return parse - - -_single = _ntuple(1) -_pair = _ntuple(2) -_triple = _ntuple(3) -_quadruple = _ntuple(4) - - -def _is_static_pad(kernel_size, stride=1, dilation=1, **_): - return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 - - -def _get_padding(kernel_size, stride=1, dilation=1, **_): - padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 - return padding - - -def _calc_same_pad(i, k, s, d): - return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) - - -def _split_channels(num_chan, num_groups): - split = [num_chan // num_groups for _ in range(num_groups)] - split[0] += num_chan - sum(split) - return split - - -# pylint: disable=unused-argument -def conv2d_same(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1): - ih, iw = x.size()[-2:] - kh, kw = weight.size()[-2:] - pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) - pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) - if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) - return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) - - -class Conv2dSame(nn.Conv2d): - """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions - """ - - # pylint: disable=unused-argument - def __init__(self, in_channels, out_channels, kernel_size, stride=1, - padding=0, dilation=1, groups=1, bias=True): - super(Conv2dSame, self).__init__( - in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) - - def forward(self, x): - return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) - - -def get_padding_value(padding, kernel_size, **kwargs): - dynamic = False - if isinstance(padding, str): - # for any string padding, the padding will be calculated for you, one of three ways - padding = padding.lower() - if padding == 'same': - # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact - if _is_static_pad(kernel_size, **kwargs): - # static case, no extra overhead - padding = _get_padding(kernel_size, **kwargs) - else: - # dynamic 'SAME' padding, has runtime/GPU memory overhead - padding = 0 - dynamic = True - elif padding == 'valid': - # 'VALID' padding, same as padding=0 - padding = 0 - else: - # Default to PyTorch style 'same'-ish symmetric padding - padding = _get_padding(kernel_size, **kwargs) - return padding, dynamic - - -def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): - padding = kwargs.pop('padding', '') - kwargs.setdefault('bias', False) - padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) - if is_dynamic: - return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) - else: - return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) - - -class MixedConv2d(nn.ModuleDict): - """ Mixed Grouped Convolution - Based on MDConv and GroupedConv in MixNet impl: - https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py - """ - def __init__(self, in_channels, out_channels, kernel_size=3, - stride=1, padding='', dilation=1, depthwise=False, **kwargs): - super(MixedConv2d, self).__init__() - - kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] - num_groups = len(kernel_size) - in_splits = _split_channels(in_channels, num_groups) - out_splits = _split_channels(out_channels, num_groups) - self.in_channels = sum(in_splits) - self.out_channels = sum(out_splits) - for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): - conv_groups = out_ch if depthwise else 1 - # use add_module to keep key space clean - self.add_module( - str(idx), - create_conv2d_pad( - in_ch, out_ch, k, stride=stride, - padding=padding, dilation=dilation, groups=conv_groups, **kwargs) - ) - self.splits = in_splits - - def forward(self, x): - x_split = torch.split(x, self.splits, 1) - x_out = [c(x_split[i]) for i, c in enumerate(self.values())] - x = torch.cat(x_out, 1) - return x - - -def get_condconv_initializer(initializer, num_experts, expert_shape): - def condconv_initializer(weight): - """CondConv initializer function.""" - num_params = np.prod(expert_shape) - if (len(weight.shape) != 2 or weight.shape[0] != num_experts or - weight.shape[1] != num_params): - raise (ValueError( - 'CondConv variables must have shape [num_experts, num_params]')) - for i in range(num_experts): - initializer(weight[i].view(expert_shape)) - return condconv_initializer - - -class CondConv2d(nn.Module): - """ Conditional Convolution - Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py - - Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: - https://github.com/pytorch/pytorch/issues/17983 - """ - __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] - - def __init__(self, in_channels, out_channels, kernel_size=3, - stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): - super(CondConv2d, self).__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = _pair(kernel_size) - self.stride = _pair(stride) - padding_val, is_padding_dynamic = get_padding_value( - padding, kernel_size, stride=stride, dilation=dilation) - self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript - self.padding = _pair(padding_val) - self.dilation = _pair(dilation) - self.groups = groups - self.num_experts = num_experts - - self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size - weight_num_param = 1 - for wd in self.weight_shape: - weight_num_param *= wd - self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) - - if bias: - self.bias_shape = (self.out_channels,) - self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) - else: - self.register_parameter('bias', None) - - self.reset_parameters() - - def reset_parameters(self): - init_weight = get_condconv_initializer( - partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) - init_weight(self.weight) - if self.bias is not None: - fan_in = np.prod(self.weight_shape[1:]) - bound = 1 / math.sqrt(fan_in) - init_bias = get_condconv_initializer( - partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) - init_bias(self.bias) - - def forward(self, x, routing_weights): - B, C, H, W = x.shape - weight = torch.matmul(routing_weights, self.weight) - new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size - weight = weight.view(new_weight_shape) - bias = None - if self.bias is not None: - bias = torch.matmul(routing_weights, self.bias) - bias = bias.view(B * self.out_channels) - # move batch elements with channels so each batch element can be efficiently convolved with separate kernel - x = x.view(1, B * C, H, W) - if self.dynamic_padding: - out = conv2d_same( - x, weight, bias, stride=self.stride, padding=self.padding, - dilation=self.dilation, groups=self.groups * B) - else: - out = F.conv2d( - x, weight, bias, stride=self.stride, padding=self.padding, - dilation=self.dilation, groups=self.groups * B) - out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) - - # Literal port (from TF definition) - # x = torch.split(x, 1, 0) - # weight = torch.split(weight, 1, 0) - # if self.bias is not None: - # bias = torch.matmul(routing_weights, self.bias) - # bias = torch.split(bias, 1, 0) - # else: - # bias = [None] * B - # out = [] - # for xi, wi, bi in zip(x, weight, bias): - # wi = wi.view(*self.weight_shape) - # if bi is not None: - # bi = bi.view(*self.bias_shape) - # out.append(self.conv_fn( - # xi, wi, bi, stride=self.stride, padding=self.padding, - # dilation=self.dilation, groups=self.groups)) - # out = torch.cat(out, 0) - return out - - -class SelectiveKernelAttn(nn.Module): - def __init__(self, channels, num_paths=2, attn_channels=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelAttn, self).__init__() - self.num_paths = num_paths - self.pool = nn.AdaptiveAvgPool2d(1) - self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) - self.bn = norm_layer(attn_channels) - self.act = act_layer(inplace=True) - self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) - - def forward(self, x): - assert x.shape[1] == self.num_paths - x = torch.sum(x, dim=1) - x = self.pool(x) - x = self.fc_reduce(x) - x = self.bn(x) - x = self.act(x) - x = self.fc_select(x) - B, C, H, W = x.shape - x = x.view(B, self.num_paths, C // self.num_paths, H, W) - x = torch.softmax(x, dim=1) - return x - - -def _kernel_valid(k): - if isinstance(k, (list, tuple)): - for ki in k: - return _kernel_valid(ki) - assert k >= 3 and k % 2 - - -class ConvBnAct(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, - drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(ConvBnAct, self).__init__() - padding = _get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block - self.conv = nn.Conv2d( - in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, - padding=padding, dilation=dilation, groups=groups, bias=False) - self.bn = norm_layer(out_channels) - self.drop_block = drop_block - if act_layer is not None: - self.act = act_layer(inplace=True) - else: - self.act = None - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - if self.drop_block is not None: - x = self.drop_block(x) - if self.act is not None: - x = self.act(x) - return x - - -class SelectiveKernelConv(nn.Module): - - def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, - attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, - drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): - super(SelectiveKernelConv, self).__init__() - kernel_size = kernel_size or [3, 5] - _kernel_valid(kernel_size) - if not isinstance(kernel_size, list): - kernel_size = [kernel_size] * 2 - if keep_3x3: - dilation = [dilation * (k - 1) // 2 for k in kernel_size] - kernel_size = [3] * len(kernel_size) - else: - dilation = [dilation] * len(kernel_size) - self.num_paths = len(kernel_size) - self.in_channels = in_channels - self.out_channels = out_channels - self.split_input = split_input - if self.split_input: - assert in_channels % self.num_paths == 0 - in_channels = in_channels // self.num_paths - groups = min(out_channels, groups) - - conv_kwargs = dict( - stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) - self.paths = nn.ModuleList([ - ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) - for k, d in zip(kernel_size, dilation)]) - - attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) - self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) - self.drop_block = drop_block - - def forward(self, x): - if self.split_input: - x_split = torch.split(x, self.in_channels // self.num_paths, 1) - x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] - else: - x_paths = [op(x) for op in self.paths] - x = torch.stack(x_paths, dim=1) - x_attn = self.attn(x) - x = x * x_attn - x = torch.sum(x, dim=1) - return x - - -# helper method -def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): - assert 'groups' not in kwargs # only use 'depthwise' bool arg - if isinstance(kernel_size, list): - assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently - # We're going to use only lists for defining the MixedConv2d kernel groups, - # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. - m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) - else: - depthwise = kwargs.pop('depthwise', False) - groups = out_chs if depthwise else 1 - if 'num_experts' in kwargs and kwargs['num_experts'] > 0: - m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) - else: - m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) - return m diff --git a/timm/models/layers/conv2d_same.py b/timm/models/layers/conv2d_same.py new file mode 100644 index 00000000..579757b8 --- /dev/null +++ b/timm/models/layers/conv2d_same.py @@ -0,0 +1,79 @@ +""" Conv2d w/ Same Padding + +Hacked together by Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Union, List, Tuple, Optional, Callable +import math + +from .conv_helpers import get_padding + + +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +def _calc_same_pad(i: int, k: int, s: int, d: int): + return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + ih, iw = x.size()[-2:] + kh, kw = weight.size()[-2:] + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if _is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = get_padding(kernel_size, **kwargs) + else: + # dynamic 'SAME' padding, has runtime/GPU memory overhead + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py new file mode 100644 index 00000000..a10c1d38 --- /dev/null +++ b/timm/models/layers/conv_bn_act.py @@ -0,0 +1,32 @@ +""" Conv2d + BN + Act + +Hacked together by Ross Wightman +""" +from torch import nn as nn + +from timm.models.layers.conv_helpers import get_padding + + +class ConvBnAct(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(ConvBnAct, self).__init__() + padding = get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block + self.conv = nn.Conv2d( + in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=False) + self.bn = norm_layer(out_channels) + self.drop_block = drop_block + if act_layer is not None: + self.act = act_layer(inplace=True) + else: + self.act = None + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.drop_block is not None: + x = self.drop_block(x) + if self.act is not None: + x = self.act(x) + return x diff --git a/timm/models/layers/conv_helpers.py b/timm/models/layers/conv_helpers.py new file mode 100644 index 00000000..3f8b160e --- /dev/null +++ b/timm/models/layers/conv_helpers.py @@ -0,0 +1,27 @@ +""" Common Helpers + +Hacked together by Ross Wightman +""" +from itertools import repeat +from torch._six import container_abcs + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, container_abcs.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +tup_single = _ntuple(1) +tup_pair = _ntuple(2) +tup_triple = _ntuple(3) +tup_quadruple = _ntuple(4) + + +# Calculate symmetric padding for a convolution +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding diff --git a/timm/models/layers/nn_ops.py b/timm/models/layers/drop.py similarity index 100% rename from timm/models/layers/nn_ops.py rename to timm/models/layers/drop.py diff --git a/timm/models/layers/mixed_conv2d.py b/timm/models/layers/mixed_conv2d.py new file mode 100644 index 00000000..3e280c03 --- /dev/null +++ b/timm/models/layers/mixed_conv2d.py @@ -0,0 +1,49 @@ +""" Conditional Convolution + +Hacked together by Ross Wightman +""" + +import torch +from torch import nn as nn + +from .conv2d_same import create_conv2d_pad + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = out_ch if depthwise else 1 + # use add_module to keep key space clean + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [c(x_split[i]) for i, c in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x diff --git a/timm/models/layers/select_conv2d.py b/timm/models/layers/select_conv2d.py new file mode 100644 index 00000000..a8713b0b --- /dev/null +++ b/timm/models/layers/select_conv2d.py @@ -0,0 +1,30 @@ +""" Select Conv2d Factory Method + +Hacked together by Ross Wightman +""" + +from .mixed_conv2d import MixedConv2d +from .cond_conv2d import CondConv2d +from .conv2d_same import create_conv2d_pad + + +def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): + """ Select a 2d convolution implementation based on arguments + Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. + + Used extensively by EfficientNet, MobileNetv3 and related networks. + """ + assert 'groups' not in kwargs # only use 'depthwise' bool arg + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + groups = out_chs if depthwise else 1 + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + return m diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py new file mode 100644 index 00000000..4100aa02 --- /dev/null +++ b/timm/models/layers/selective_kernel.py @@ -0,0 +1,88 @@ +""" Selective Kernel Convolution Attention + +Hacked together by Ross Wightman +""" + +import torch +from torch import nn as nn + +from .conv_bn_act import ConvBnAct + + +def _kernel_valid(k): + if isinstance(k, (list, tuple)): + for ki in k: + return _kernel_valid(ki) + assert k >= 3 and k % 2 + + +class SelectiveKernelAttn(nn.Module): + def __init__(self, channels, num_paths=2, attn_channels=32, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelAttn, self).__init__() + self.num_paths = num_paths + self.pool = nn.AdaptiveAvgPool2d(1) + self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) + self.bn = norm_layer(attn_channels) + self.act = act_layer(inplace=True) + self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) + + def forward(self, x): + assert x.shape[1] == self.num_paths + x = torch.sum(x, dim=1) + x = self.pool(x) + x = self.fc_reduce(x) + x = self.bn(x) + x = self.act(x) + x = self.fc_select(x) + B, C, H, W = x.shape + x = x.view(B, self.num_paths, C // self.num_paths, H, W) + x = torch.softmax(x, dim=1) + return x + + +class SelectiveKernelConv(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, + attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, + drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelConv, self).__init__() + kernel_size = kernel_size or [3, 5] + _kernel_valid(kernel_size) + if not isinstance(kernel_size, list): + kernel_size = [kernel_size] * 2 + if keep_3x3: + dilation = [dilation * (k - 1) // 2 for k in kernel_size] + kernel_size = [3] * len(kernel_size) + else: + dilation = [dilation] * len(kernel_size) + self.num_paths = len(kernel_size) + self.in_channels = in_channels + self.out_channels = out_channels + self.split_input = split_input + if self.split_input: + assert in_channels % self.num_paths == 0 + in_channels = in_channels // self.num_paths + groups = min(out_channels, groups) + + conv_kwargs = dict( + stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) + self.paths = nn.ModuleList([ + ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + for k, d in zip(kernel_size, dilation)]) + + attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) + self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) + self.drop_block = drop_block + + def forward(self, x): + if self.split_input: + x_split = torch.split(x, self.in_channels // self.num_paths, 1) + x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] + else: + x_paths = [op(x) for op in self.paths] + x = torch.stack(x_paths, dim=1) + x_attn = self.attn(x) + x = x * x_attn + x = torch.sum(x, dim=1) + return x diff --git a/timm/models/layers/test_time_pool.py b/timm/models/layers/test_time_pool.py index ce6ddf07..33e24970 100644 --- a/timm/models/layers/test_time_pool.py +++ b/timm/models/layers/test_time_pool.py @@ -1,3 +1,8 @@ +""" Test Time Pooling (Average-Max Pool) + +Hacked together by Ross Wightman +""" + import logging from torch import nn import torch.nn.functional as F diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 76f6363c..9d4de856 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -7,8 +7,6 @@ Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 Hacked together by Ross Wightman """ -import torch.nn as nn -import torch.nn.functional as F from .efficientnet_builder import * from .registry import register_model From f902bcd54cd071fc120e1bcb20341d801f15ddd8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 10 Feb 2020 11:55:03 -0800 Subject: [PATCH 20/23] Layer refactoring continues, ResNet downsample rewrite for proper dilation in 3x3 and avg_pool cases * select_conv2d -> create_conv2d * added create_attn to create attention module from string/bool/module * factor padding helpers into own file, use in both conv2d_same and avg_pool2d_same * add some more test eca resnet variants * minor tweaks, naming, comments, consistency --- timm/models/efficientnet.py | 12 +- timm/models/efficientnet_blocks.py | 18 +- timm/models/gluon_resnet.py | 20 ++- timm/models/layers/__init__.py | 7 +- timm/models/layers/avg_pool2d_same.py | 31 ++++ timm/models/layers/cond_conv2d.py | 2 +- timm/models/layers/conv2d_same.py | 19 +- timm/models/layers/conv_bn_act.py | 2 +- timm/models/layers/create_attn.py | 30 ++++ .../{select_conv2d.py => create_conv2d.py} | 4 +- timm/models/layers/drop.py | 15 +- timm/models/layers/eca.py | 36 ++-- .../layers/{conv_helpers.py => helpers.py} | 10 +- timm/models/layers/padding.py | 33 ++++ timm/models/layers/se.py | 21 +++ timm/models/layers/test_time_pool.py | 2 + timm/models/mobilenetv3.py | 8 +- timm/models/res2net.py | 10 +- timm/models/resnet.py | 166 ++++++++++-------- timm/models/sknet.py | 28 ++- 20 files changed, 311 insertions(+), 163 deletions(-) create mode 100644 timm/models/layers/avg_pool2d_same.py create mode 100644 timm/models/layers/create_attn.py rename timm/models/layers/{select_conv2d.py => create_conv2d.py} (92%) rename timm/models/layers/{conv_helpers.py => helpers.py} (62%) create mode 100644 timm/models/layers/padding.py create mode 100644 timm/models/layers/se.py diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index c5dcacd3..ea71c873 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -28,7 +28,7 @@ from .feature_hooks import FeatureHooks from .registry import register_model from .helpers import load_pretrained from .layers import SelectAdaptivePool2d -from timm.models.layers import select_conv2d +from timm.models.layers import create_conv2d from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD @@ -220,7 +220,7 @@ class EfficientNet(nn.Module): def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None, - pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'): super(EfficientNet, self).__init__() norm_kwargs = norm_kwargs or {} @@ -232,21 +232,21 @@ class EfficientNet(nn.Module): # Stem stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) - self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) self._in_chs = stem_size # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( - channel_multiplier, channel_divisor, channel_min, 32, pad_type, act_layer, se_kwargs, + channel_multiplier, channel_divisor, channel_min, output_stride, pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate, verbose=_DEBUG) self.blocks = nn.Sequential(*builder(self._in_chs, block_args)) self.feature_info = builder.features self._in_chs = builder.in_chs # Head + Pooling - self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type) + self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type) self.bn2 = norm_layer(self.num_features, **norm_kwargs) self.act2 = act_layer(inplace=True) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) @@ -314,7 +314,7 @@ class EfficientNetFeatures(nn.Module): # Stem stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) - self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) self._in_chs = stem_size diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index a231fa31..c87c2237 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from torch.nn import functional as F from .layers.activations import sigmoid -from .layers import select_conv2d +from .layers import create_conv2d # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per @@ -129,7 +129,7 @@ class ConvBnAct(nn.Module): norm_layer=nn.BatchNorm2d, norm_kwargs=None): super(ConvBnAct, self).__init__() norm_kwargs = norm_kwargs or {} - self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type) + self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type) self.bn1 = norm_layer(out_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) @@ -162,7 +162,7 @@ class DepthwiseSeparableConv(nn.Module): self.has_pw_act = pw_act # activation after point-wise conv self.drop_connect_rate = drop_connect_rate - self.conv_dw = select_conv2d( + self.conv_dw = create_conv2d( in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True) self.bn1 = norm_layer(in_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) @@ -174,7 +174,7 @@ class DepthwiseSeparableConv(nn.Module): else: self.se = None - self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) self.bn2 = norm_layer(out_chs, **norm_kwargs) self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() @@ -223,12 +223,12 @@ class InvertedResidual(nn.Module): self.drop_connect_rate = drop_connect_rate # Point-wise expansion - self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) self.bn1 = norm_layer(mid_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) # Depth-wise convolution - self.conv_dw = select_conv2d( + self.conv_dw = create_conv2d( mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True, **conv_kwargs) self.bn2 = norm_layer(mid_chs, **norm_kwargs) @@ -242,7 +242,7 @@ class InvertedResidual(nn.Module): self.se = None # Point-wise linear projection - self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) + self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) self.bn3 = norm_layer(out_chs, **norm_kwargs) def feature_module(self, location): @@ -356,7 +356,7 @@ class EdgeResidual(nn.Module): self.drop_connect_rate = drop_connect_rate # Expansion convolution - self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) + self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) self.bn1 = norm_layer(mid_chs, **norm_kwargs) self.act1 = act_layer(inplace=True) @@ -368,7 +368,7 @@ class EdgeResidual(nn.Module): self.se = None # Point-wise linear projection - self.conv_pwl = select_conv2d( + self.conv_pwl = create_conv2d( mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type) self.bn2 = norm_layer(out_chs, **norm_kwargs) diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index f835a485..6ccc4c53 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -11,6 +11,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained +from .layers import SEModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .resnet import ResNet, Bottleneck, BasicBlock @@ -319,8 +320,8 @@ def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kw """ default_cfg = default_cfgs['gluon_seresnext50_32x4d'] model = ResNet( - Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, + num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -333,8 +334,8 @@ def gluon_seresnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **k """ default_cfg = default_cfgs['gluon_seresnext101_32x4d'] model = ResNet( - Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, + num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer=SEModule), **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -346,9 +347,10 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k """Constructs a SEResNeXt-101-64x4d model. """ default_cfg = default_cfgs['gluon_seresnext101_64x4d'] + block_args = dict(attn_layer=SEModule) model = ResNet( - Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4, + num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -360,10 +362,10 @@ def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs an SENet-154 model. """ default_cfg = default_cfgs['gluon_senet154'] + block_args = dict(attn_layer=SEModule) model = ResNet( - Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, use_se=True, - stem_type='deep', down_kernel_size=3, block_reduce_first=2, - num_classes=num_classes, in_chans=in_chans, **kwargs) + Bottleneck, [3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3, + block_reduce_first=2, num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 79aa9ac2..828c20b2 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -1,8 +1,13 @@ +from .padding import get_padding +from .avg_pool2d_same import AvgPool2dSame +from .conv2d_same import Conv2dSame from .conv_bn_act import ConvBnAct from .mixed_conv2d import MixedConv2d from .cond_conv2d import CondConv2d, get_condconv_initializer -from .select_conv2d import select_conv2d +from .create_conv2d import create_conv2d +from .create_attn import create_attn from .selective_kernel import SelectiveKernelConv +from .se import SEModule from .eca import EcaModule, CecaModule from .activations import * from .adaptive_avgmax_pool import \ diff --git a/timm/models/layers/avg_pool2d_same.py b/timm/models/layers/avg_pool2d_same.py new file mode 100644 index 00000000..33656e79 --- /dev/null +++ b/timm/models/layers/avg_pool2d_same.py @@ -0,0 +1,31 @@ +""" AvgPool2d w/ Same Padding + +Hacked together by Ross Wightman +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List +import math + +from .helpers import tup_pair +from .padding import pad_same + + +def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), + ceil_mode: bool = False, count_include_pad: bool = True): + x = pad_same(x, kernel_size, stride) + return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + +class AvgPool2dSame(nn.AvgPool2d): + """ Tensorflow like 'SAME' wrapper for 2D average pooling + """ + def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): + kernel_size = tup_pair(kernel_size) + stride = tup_pair(stride) + super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) + + def forward(self, x): + return avg_pool2d_same( + x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py index d6cba889..a7a424a6 100644 --- a/timm/models/layers/cond_conv2d.py +++ b/timm/models/layers/cond_conv2d.py @@ -10,8 +10,8 @@ import torch from torch import nn as nn from torch.nn import functional as F +from .helpers import tup_pair from .conv2d_same import get_padding_value, conv2d_same -from .conv_helpers import tup_pair def get_condconv_initializer(initializer, num_experts, expert_shape): diff --git a/timm/models/layers/conv2d_same.py b/timm/models/layers/conv2d_same.py index 579757b8..0e29ae8c 100644 --- a/timm/models/layers/conv2d_same.py +++ b/timm/models/layers/conv2d_same.py @@ -8,26 +8,13 @@ import torch.nn.functional as F from typing import Union, List, Tuple, Optional, Callable import math -from .conv_helpers import get_padding - - -def _is_static_pad(kernel_size, stride=1, dilation=1, **_): - return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 - - -def _calc_same_pad(i: int, k: int, s: int, d: int): - return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) +from .padding import get_padding, pad_same, is_static_pad def conv2d_same( x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): - ih, iw = x.size()[-2:] - kh, kw = weight.size()[-2:] - pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) - pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) - if pad_h > 0 or pad_w > 0: - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + x = pad_same(x, weight.shape[-2:], stride, dilation) return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) @@ -51,7 +38,7 @@ def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: padding = padding.lower() if padding == 'same': # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact - if _is_static_pad(kernel_size, **kwargs): + if is_static_pad(kernel_size, **kwargs): # static case, no extra overhead padding = get_padding(kernel_size, **kwargs) else: diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py index a10c1d38..f5c94720 100644 --- a/timm/models/layers/conv_bn_act.py +++ b/timm/models/layers/conv_bn_act.py @@ -4,7 +4,7 @@ Hacked together by Ross Wightman """ from torch import nn as nn -from timm.models.layers.conv_helpers import get_padding +from timm.models.layers import get_padding class ConvBnAct(nn.Module): diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py new file mode 100644 index 00000000..c8aba217 --- /dev/null +++ b/timm/models/layers/create_attn.py @@ -0,0 +1,30 @@ +""" Select AttentionFactory Method + +Hacked together by Ross Wightman +""" +import torch +from .se import SEModule +from .eca import EcaModule, CecaModule + + +def create_attn(attn_type, channels, **kwargs): + module_cls = None + if attn_type is not None: + if isinstance(attn_type, str): + attn_type = attn_type.lower() + if attn_type == 'se': + module_cls = SEModule + elif attn_type == 'eca': + module_cls = EcaModule + elif attn_type == 'eca': + module_cls = CecaModule + else: + assert False, "Invalid attn module (%s)" % attn_type + elif isinstance(attn_type, bool): + if attn_type: + module_cls = SEModule + else: + module_cls = attn_type + if module_cls is not None: + return module_cls(channels, **kwargs) + return None diff --git a/timm/models/layers/select_conv2d.py b/timm/models/layers/create_conv2d.py similarity index 92% rename from timm/models/layers/select_conv2d.py rename to timm/models/layers/create_conv2d.py index a8713b0b..527c80a3 100644 --- a/timm/models/layers/select_conv2d.py +++ b/timm/models/layers/create_conv2d.py @@ -1,4 +1,4 @@ -""" Select Conv2d Factory Method +""" Create Conv2d Factory Method Hacked together by Ross Wightman """ @@ -8,7 +8,7 @@ from .cond_conv2d import CondConv2d from .conv2d_same import create_conv2d_pad -def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): +def create_conv2d(in_chs, out_chs, kernel_size, **kwargs): """ Select a 2d convolution implementation based on arguments Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py index 30b98427..46d5d20b 100644 --- a/timm/models/layers/drop.py +++ b/timm/models/layers/drop.py @@ -1,3 +1,9 @@ +""" DropBlock, DropPath + +PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. + +Hacked together by Ross Wightman +""" import torch import torch.nn as nn import torch.nn.functional as F @@ -6,6 +12,8 @@ import math def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noise=False): + """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf + """ _, _, height, width = x.shape total_size = width * height clipped_block_size = min(block_size, min(width, height)) @@ -24,7 +32,7 @@ def drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0, drop_with_noi block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float() block_mask = -F.max_pool2d( -block_mask, - kernel_size=clipped_block_size, # block_size, + kernel_size=clipped_block_size, # block_size, ??? stride=1, padding=clipped_block_size // 2) @@ -58,7 +66,8 @@ class DropBlock2d(nn.Module): def drop_path(x, drop_prob=0.): - """Drop paths (Stochastic Depth) per sample (when applied in residual blocks).""" + """Drop paths (Stochastic Depth) per sample (when applied in residual blocks). + """ keep_prob = 1 - drop_prob random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device) random_tensor.floor_() # binarize @@ -67,6 +76,8 @@ def drop_path(x, drop_prob=0.): class DropPath(nn.ModuleDict): + """Drop paths (Stochastic Depth) per sample (when applied in residual blocks). + """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py index fab205cb..5e64f649 100644 --- a/timm/models/layers/eca.py +++ b/timm/models/layers/eca.py @@ -47,19 +47,20 @@ class EcaModule(nn.Module): gamma, beta: when channel is given parameters of mapping function refer to original paper https://arxiv.org/pdf/1910.03151.pdf (default=None. if channel size not given, use k_size given for kernel size.) - k_size: Adaptive selection of kernel size (default=3) + kernel_size: Adaptive selection of kernel size (default=3) """ - def __init__(self, channel=None, k_size=3, gamma=2, beta=1): + def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): super(EcaModule, self).__init__() - assert k_size % 2 == 1 + assert kernel_size % 2 == 1 + + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) - if channel is not None: - t = int(abs(math.log(channel, 2)+beta) / gamma) - k_size = t if t % 2 else t + 1 + print('florg', kernel_size) self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) - self.sigmoid = nn.Sigmoid() + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) def forward(self, x): # Feature descriptor on the global spatial information @@ -69,7 +70,7 @@ class EcaModule(nn.Module): # Two different branches of ECA module y = self.conv(y) # Multi-scale information fusion - y = self.sigmoid(y.view(x.shape[0], -1, 1, 1)) + y = y.view(x.shape[0], -1, 1, 1).sigmoid() return x * y.expand_as(x) @@ -93,22 +94,21 @@ class CecaModule(nn.Module): k_size: Adaptive selection of kernel size (default=3) """ - def __init__(self, channel=None, k_size=3, gamma=2, beta=1): + def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): super(CecaModule, self).__init__() - assert k_size % 2 == 1 + assert kernel_size % 2 == 1 - if channel is not None: - t = int(abs(math.log(channel, 2)+beta) / gamma) - k_size = t if t % 2 else t + 1 + if channels is not None: + t = int(abs(math.log(channels, 2) + beta) / gamma) + kernel_size = max(t if t % 2 else t + 1, 3) self.avg_pool = nn.AdaptiveAvgPool2d(1) #pytorch circular padding mode is buggy as of pytorch 1.4 #see https://github.com/pytorch/pytorch/pull/17240 #implement manual circular padding - self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=0, bias=False) - self.padding = (k_size - 1) // 2 - self.sigmoid = nn.Sigmoid() + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False) + self.padding = (kernel_size - 1) // 2 def forward(self, x): # Feature descriptor on the global spatial information @@ -121,6 +121,6 @@ class CecaModule(nn.Module): y = self.conv(y) # Multi-scale information fusion - y = self.sigmoid(y.view(x.shape[0], -1, 1, 1)) + y = y.view(x.shape[0], -1, 1, 1).sigmoid() return x * y.expand_as(x) diff --git a/timm/models/layers/conv_helpers.py b/timm/models/layers/helpers.py similarity index 62% rename from timm/models/layers/conv_helpers.py rename to timm/models/layers/helpers.py index 3f8b160e..967c2f4c 100644 --- a/timm/models/layers/conv_helpers.py +++ b/timm/models/layers/helpers.py @@ -1,4 +1,4 @@ -""" Common Helpers +""" Layer/Module Helpers Hacked together by Ross Wightman """ @@ -21,7 +21,7 @@ tup_triple = _ntuple(3) tup_quadruple = _ntuple(4) -# Calculate symmetric padding for a convolution -def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: - padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 - return padding + + + + diff --git a/timm/models/layers/padding.py b/timm/models/layers/padding.py new file mode 100644 index 00000000..b3653866 --- /dev/null +++ b/timm/models/layers/padding.py @@ -0,0 +1,33 @@ +""" Padding Helpers + +Hacked together by Ross Wightman +""" +import math +from typing import List + +import torch.nn.functional as F + + +# Calculate symmetric padding for a convolution +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution +def get_same_padding(x: int, k: int, s: int, d: int): + return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) + + +# Can SAME padding for given args be done statically? +def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +# Dynamically pad input x with 'SAME' padding for conv with specified args +def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1)): + ih, iw = x.size()[-2:] + pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return x diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py new file mode 100644 index 00000000..de87ccf5 --- /dev/null +++ b/timm/models/layers/se.py @@ -0,0 +1,21 @@ +from torch import nn as nn + + +class SEModule(nn.Module): + + def __init__(self, channels, reduction=16, act_layer=nn.ReLU): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + reduction_channels = max(channels // reduction, 8) + self.fc1 = nn.Conv2d( + channels, reduction_channels, kernel_size=1, padding=0, bias=True) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d( + reduction_channels, channels, kernel_size=1, padding=0, bias=True) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.fc1(x_se) + x_se = self.act(x_se) + x_se = self.fc2(x_se) + return x * x_se.sigmoid() diff --git a/timm/models/layers/test_time_pool.py b/timm/models/layers/test_time_pool.py index 33e24970..dcfc66ca 100644 --- a/timm/models/layers/test_time_pool.py +++ b/timm/models/layers/test_time_pool.py @@ -34,6 +34,8 @@ class TestTimePoolHead(nn.Module): def apply_test_time_pool(model, config, args): test_time_pool = False + if not hasattr(model, 'default_cfg') or not model.default_cfg: + return model, False if not args.no_test_pool and \ config['input_size'][-1] > model.default_cfg['input_size'][-1] and \ config['input_size'][-2] > model.default_cfg['input_size'][-2]: diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 9d4de856..c74f4224 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -11,7 +11,7 @@ Hacked together by Ross Wightman from .efficientnet_builder import * from .registry import register_model from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d, select_conv2d +from .layers import SelectAdaptivePool2d, create_conv2d from .layers.activations import HardSwish, hard_sigmoid from .feature_hooks import FeatureHooks from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD @@ -82,7 +82,7 @@ class MobileNetV3(nn.Module): # Stem stem_size = round_channels(stem_size, channel_multiplier) - self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) self._in_chs = stem_size @@ -97,7 +97,7 @@ class MobileNetV3(nn.Module): # Head + Pooling self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.conv_head = select_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias) + self.conv_head = create_conv2d(self._in_chs, self.num_features, 1, padding=pad_type, bias=head_bias) self.act2 = act_layer(inplace=True) # Classifier @@ -162,7 +162,7 @@ class MobileNetV3Features(nn.Module): # Stem stem_size = round_channels(stem_size, channel_multiplier) - self.conv_stem = select_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) + self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) self._in_chs = stem_size diff --git a/timm/models/res2net.py b/timm/models/res2net.py index b8d31b3e..134cf00d 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -8,10 +8,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from .resnet import ResNet, SEModule +from .resnet import ResNet from .registry import register_model from .helpers import load_pretrained -from .layers import SelectAdaptivePool2d +from .layers import SEModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = [] @@ -53,8 +53,8 @@ class Bottle2neck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=26, scale=4, use_se=False, - act_layer=nn.ReLU, norm_layer=None, dilation=1, first_dilation=None, **_): + cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None, + act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_): super(Bottle2neck, self).__init__() self.scale = scale self.is_first = stride > 1 or downsample is not None @@ -82,7 +82,7 @@ class Bottle2neck(nn.Module): self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False) self.bn3 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None + self.se = attn_layer(outplanes) if attn_layer is not None else None self.relu = act_layer(inplace=True) self.downsample = downsample diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 528d5790..5b020272 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -7,13 +7,12 @@ ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered ste """ import math -import torch import torch.nn as nn import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained -from .layers import EcaModule, SelectAdaptivePool2d, DropBlock2d, DropPath +from .layers import SelectAdaptivePool2d, DropBlock2d, DropPath, AvgPool2dSame, create_attn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -103,7 +102,8 @@ default_cfgs = { 'ecaresnext26tn_32x4d': _cfg( url='', interpolation='bicubic'), - + 'ecaresnet18': _cfg(), + 'ecaresnet50': _cfg(), } @@ -112,32 +112,12 @@ def get_padding(kernel_size, stride, dilation=1): return padding -class SEModule(nn.Module): - - def __init__(self, channels, reduction_channels): - super(SEModule, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc1 = nn.Conv2d( - channels, reduction_channels, kernel_size=1, padding=0, bias=True) - self.relu = nn.ReLU(inplace=True) - self.fc2 = nn.Conv2d( - reduction_channels, channels, kernel_size=1, padding=0, bias=True) - - def forward(self, x): - x_se = self.avg_pool(x) - x_se = self.fc1(x_se) - x_se = self.relu(x_se) - x_se = self.fc2(x_se) - return x * x_se.sigmoid() - - class BasicBlock(nn.Module): - __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 1 - def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - drop_block=None, drop_path=None): + attn_layer=None, drop_block=None, drop_path=None): super(BasicBlock, self).__init__() assert cardinality == 1, 'BasicBlock only supports cardinality of 1' @@ -155,7 +135,7 @@ class BasicBlock(nn.Module): first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) self.bn2 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None + self.se = create_attn(attn_layer, outplanes) self.act2 = act_layer(inplace=True) self.downsample = downsample @@ -199,9 +179,9 @@ class Bottleneck(nn.Module): __constants__ = ['se', 'downsample'] # for pre 1.4 torchscript compat expansion = 4 - def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False, + def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - drop_block=None, drop_path=None): + attn_layer=None, drop_block=None, drop_path=None): super(Bottleneck, self).__init__() width = int(math.floor(planes * (base_width / 64)) * cardinality) @@ -220,7 +200,7 @@ class Bottleneck(nn.Module): self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.bn3 = norm_layer(outplanes) - self.se = SEModule(outplanes, planes // 4) if use_se else None + self.se = create_attn(attn_layer, outplanes) self.act3 = act_layer(inplace=True) self.downsample = downsample @@ -266,6 +246,37 @@ class Bottleneck(nn.Module): return x +def downsample_conv( + in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size + first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 + p = get_padding(kernel_size, stride, first_dilation) + + return nn.Sequential(*[ + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False), + norm_layer(out_channels) + ]) + + +def downsample_avg( + in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + avg_stride = stride if dilation == 1 else 1 + if stride == 1 and dilation == 1: + pool = nn.Identity() + else: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + + return nn.Sequential(*[ + pool, + nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False), + norm_layer(out_channels) + ]) + + class ResNet(nn.Module): """ResNet / ResNeXt / SE-ResNeXt / SE-Net @@ -307,8 +318,6 @@ class ResNet(nn.Module): Number of classification classes. in_chans : int, default 3 Number of input (color) channels. - use_se : bool, default False - Enable Squeeze-Excitation module in blocks cardinality : int, default 1 Number of convolution groups for 3x3 conv in Bottleneck. base_width : int, default 64 @@ -337,7 +346,7 @@ class ResNet(nn.Module): global_pool : str, default 'avg' Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' """ - def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, use_eca=False, + def __init__(self, block, layers, num_classes=1000, in_chans=3, cardinality=1, base_width=64, stem_width=64, stem_type='', block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0., @@ -385,14 +394,14 @@ class ResNet(nn.Module): dilations[2:4] = [2, 4] else: assert output_stride == 32 - llargs = list(zip(channels, layers, strides, dilations)) - lkwargs = dict( - use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, + layer_args = list(zip(channels, layers, strides, dilations)) + layer_kwargs = dict( + reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args) - self.layer1 = self._make_layer(block, *llargs[0], **lkwargs) - self.layer2 = self._make_layer(block, *llargs[1], **lkwargs) - self.layer3 = self._make_layer(block, drop_block=db_3, *llargs[2], **lkwargs) - self.layer4 = self._make_layer(block, drop_block=db_4, *llargs[3], **lkwargs) + self.layer1 = self._make_layer(block, *layer_args[0], **layer_kwargs) + self.layer2 = self._make_layer(block, *layer_args[1], **layer_kwargs) + self.layer3 = self._make_layer(block, drop_block=db_3, *layer_args[2], **layer_kwargs) + self.layer4 = self._make_layer(block, drop_block=db_4, *layer_args[3], **layer_kwargs) # Head (Pooling and Classifier) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) @@ -411,31 +420,21 @@ class ResNet(nn.Module): m.zero_init_last_bn() def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1, - use_se=False, use_eca=False,avg_down=False, down_kernel_size=1, **kwargs): - norm_layer = kwargs.get('norm_layer') + avg_down=False, down_kernel_size=1, **kwargs): downsample = None - down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size + first_dilation = 1 if dilation in (1, 2) else 2 if stride != 1 or self.inplanes != planes * block.expansion: - downsample_padding = get_padding(down_kernel_size, stride) - downsample_layers = [] - conv_stride = stride - if avg_down: - avg_stride = stride if dilation == 1 else 1 - conv_stride = 1 - downsample_layers = [nn.AvgPool2d(avg_stride, avg_stride, ceil_mode=True, count_include_pad=False)] - downsample_layers += [ - nn.Conv2d(self.inplanes, planes * block.expansion, down_kernel_size, - stride=conv_stride, padding=downsample_padding, bias=False), - norm_layer(planes * block.expansion)] - downsample = nn.Sequential(*downsample_layers) + downsample_args = dict( + in_channels=self.inplanes, out_channels=planes * block.expansion, kernel_size=down_kernel_size, + stride=stride, dilation=dilation, first_dilation=first_dilation, norm_layer=kwargs.get('norm_layer')) + downsample = downsample_avg(**downsample_args) if avg_down else downsample_conv(**downsample_args) - first_dilation = 1 if dilation in (1, 2) else 2 - bkwargs = dict( + block_kwargs = dict( cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, - dilation=dilation, use_se=use_se, **kwargs) - layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **bkwargs)] + dilation=dilation, **kwargs) + layers = [block(self.inplanes, planes, stride, downsample, first_dilation=first_dilation, **block_kwargs)] self.inplanes = planes * block.expansion - layers += [block(self.inplanes, planes, **bkwargs) for _ in range(1, blocks)] + layers += [block(self.inplanes, planes, **block_kwargs) for _ in range(1, blocks)] return nn.Sequential(*layers) @@ -936,9 +935,8 @@ def seresnext26d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs) """ default_cfg = default_cfgs['seresnext26d_32x4d'] model = ResNet( - Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, - stem_width=32, stem_type='deep', avg_down=True, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, stem_type='deep', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -954,8 +952,8 @@ def seresnext26t_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs) default_cfg = default_cfgs['seresnext26t_32x4d'] model = ResNet( Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, - stem_width=32, stem_type='deep_tiered', avg_down=True, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + stem_width=32, stem_type='deep_tiered', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -971,25 +969,55 @@ def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs default_cfg = default_cfgs['seresnext26tn_32x4d'] model = ResNet( Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, - stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=dict(attn_layer='se'), **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model + @register_model def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a eca-ResNeXt-26-TN model. + """Constructs an ECA-ResNeXt-26-TN model. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant. this model replaces SE module with the ECA module """ default_cfg = default_cfgs['ecaresnext26tn_32x4d'] + block_args = dict(attn_layer='eca') model = ResNet( Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, - stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_eca=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) + stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, + num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def ecaresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """ Constructs an ECA-ResNet-18 model. + """ + default_cfg = default_cfgs['ecaresnet18'] + block_args = dict(attn_layer='eca') + model = ResNet( + BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +@register_model +def ecaresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs an ECA-ResNet-50 model. + """ + default_cfg = default_cfgs['ecaresnet50'] + block_args = dict(attn_layer='eca') + model = ResNet( + Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, block_args=block_args, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 032b7d0b..6db37da5 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -4,8 +4,8 @@ from torch import nn as nn from .registry import register_model from .helpers import load_pretrained -from .layers import SelectiveKernelConv, ConvBnAct -from .resnet import ResNet, SEModule +from .layers import SelectiveKernelConv, ConvBnAct, create_attn +from .resnet import ResNet from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -33,8 +33,8 @@ class SelectiveKernelBasic(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, - use_se=False, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, - drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, + drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None): super(SelectiveKernelBasic, self).__init__() sk_kwargs = sk_kwargs or {} @@ -42,7 +42,7 @@ class SelectiveKernelBasic(nn.Module): assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first - out_planes = planes * self.expansion + outplanes = planes * self.expansion first_dilation = first_dilation or dilation _selective_first = True # FIXME temporary, for experiments @@ -51,14 +51,14 @@ class SelectiveKernelBasic(nn.Module): inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) conv_kwargs['act_layer'] = None self.conv2 = ConvBnAct( - first_planes, out_planes, kernel_size=3, dilation=dilation, **conv_kwargs) + first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs) else: self.conv1 = ConvBnAct( inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs) conv_kwargs['act_layer'] = None self.conv2 = SelectiveKernelConv( - first_planes, out_planes, dilation=dilation, **conv_kwargs, **sk_kwargs) - self.se = SEModule(out_planes, planes // 4) if use_se else None + first_planes, outplanes, dilation=dilation, **conv_kwargs, **sk_kwargs) + self.se = create_attn(attn_layer, outplanes) self.act = act_layer(inplace=True) self.downsample = downsample self.stride = stride @@ -88,17 +88,15 @@ class SelectiveKernelBottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, sk_kwargs=None, - reduce_first=1, dilation=1, first_dilation=None, - drop_block=None, drop_path=None, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, + drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None): super(SelectiveKernelBottleneck, self).__init__() sk_kwargs = sk_kwargs or {} conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer) width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first - out_planes = planes * self.expansion + outplanes = planes * self.expansion first_dilation = first_dilation or dilation self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) @@ -106,8 +104,8 @@ class SelectiveKernelBottleneck(nn.Module): first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, **conv_kwargs, **sk_kwargs) conv_kwargs['act_layer'] = None - self.conv3 = ConvBnAct(width, out_planes, kernel_size=1, **conv_kwargs) - self.se = SEModule(out_planes, planes // 4) if use_se else None + self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs) + self.se = create_attn(attn_layer, outplanes) self.act = act_layer(inplace=True) self.downsample = downsample self.stride = stride From 2a7d256fd5103c06c6e9ab20bbf2aa9a1e4fd114 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 10 Feb 2020 11:59:36 -0800 Subject: [PATCH 21/23] Re-enable mem-efficient/jit activations after torchscript tests --- timm/models/layers/activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/layers/activations.py b/timm/models/layers/activations.py index 165b7951..6f8d2f89 100644 --- a/timm/models/layers/activations.py +++ b/timm/models/layers/activations.py @@ -12,7 +12,7 @@ from torch import nn as nn from torch.nn import functional as F -_USE_MEM_EFFICIENT_ISH = False +_USE_MEM_EFFICIENT_ISH = True if _USE_MEM_EFFICIENT_ISH: # This version reduces memory overhead of Swish during training by # recomputing torch.sigmoid(x) in backward instead of saving it. From d7259918703888bde3870a44aac62edbaab41f44 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 10 Feb 2020 16:21:33 -0800 Subject: [PATCH 22/23] Remove debug print from ECA module --- timm/models/layers/eca.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py index 5e64f649..7ca5033d 100644 --- a/timm/models/layers/eca.py +++ b/timm/models/layers/eca.py @@ -57,8 +57,6 @@ class EcaModule(nn.Module): t = int(abs(math.log(channels, 2) + beta) / gamma) kernel_size = max(t if t % 2 else t + 1, 3) - print('florg', kernel_size) - self.avg_pool = nn.AdaptiveAvgPool2d(1) self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) @@ -86,12 +84,12 @@ class CecaModule(nn.Module): (parameter size, throughput,latency, etc) Args: - channel: Number of channels of the input feature map for use in adaptive kernel sizes + channels: Number of channels of the input feature map for use in adaptive kernel sizes for actual calculations according to channel. gamma, beta: when channel is given parameters of mapping function refer to original paper https://arxiv.org/pdf/1910.03151.pdf (default=None. if channel size not given, use k_size given for kernel size.) - k_size: Adaptive selection of kernel size (default=3) + kernel_size: Adaptive selection of kernel size (default=3) """ def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1): From 5e6dbbaf30e1e6d027a1970425b0a63cb6f21cfe Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 10 Feb 2020 16:23:09 -0800 Subject: [PATCH 23/23] Add CBAM for experimentation --- timm/models/layers/cbam.py | 97 +++++++++++++++++++++++++++++++ timm/models/layers/create_attn.py | 5 ++ 2 files changed, 102 insertions(+) create mode 100644 timm/models/layers/cbam.py diff --git a/timm/models/layers/cbam.py b/timm/models/layers/cbam.py new file mode 100644 index 00000000..37ba1c35 --- /dev/null +++ b/timm/models/layers/cbam.py @@ -0,0 +1,97 @@ +""" CBAM (sort-of) Attention + +Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 + +Hacked together by Ross Wightman +""" + +import torch +from torch import nn as nn +from .conv_bn_act import ConvBnAct + + +class ChannelAttn(nn.Module): + """ Original CBAM channel attention module, currently avg + max pool variant only. + """ + def __init__(self, channels, reduction=16, act_layer=nn.ReLU): + super(ChannelAttn, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False) + self.act = act_layer(inplace=True) + self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False) + + def forward(self, x): + x_avg = self.avg_pool(x) + x_max = self.max_pool(x) + x_avg = self.fc2(self.act(self.fc1(x_avg))) + x_max = self.fc2(self.act(self.fc1(x_max))) + x_attn = x_avg + x_max + return x * x_attn.sigmoid() + + +class LightChannelAttn(ChannelAttn): + """An experimental 'lightweight' that sums avg + max pool first + """ + def __init__(self, channels, reduction=16): + super(LightChannelAttn, self).__init__(channels, reduction) + + def forward(self, x): + x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x) + x_attn = self.fc2(self.act(self.fc1(x_pool))) + return x * x_attn.sigmoid() + + +class SpatialAttn(nn.Module): + """ Original CBAM spatial attention module + """ + def __init__(self, kernel_size=7): + super(SpatialAttn, self).__init__() + self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) + + def forward(self, x): + x_avg = torch.mean(x, dim=1, keepdim=True) + x_max = torch.max(x, dim=1, keepdim=True)[0] + x_attn = torch.cat([x_avg, x_max], dim=1) + x_attn = self.conv(x_attn) + return x * x_attn.sigmoid() + + +class LightSpatialAttn(nn.Module): + """An experimental 'lightweight' variant that sums avg_pool and max_pool results. + """ + def __init__(self, kernel_size=7): + super(LightSpatialAttn, self).__init__() + self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) + + def forward(self, x): + x_avg = torch.mean(x, dim=1, keepdim=True) + x_max = torch.max(x, dim=1, keepdim=True)[0] + x_attn = 0.5 * x_avg + 0.5 * x_max + x_attn = self.conv(x_attn) + return x * x_attn.sigmoid() + + +class CbamModule(nn.Module): + def __init__(self, channels, spatial_kernel_size=7): + super(CbamModule, self).__init__() + self.channel = ChannelAttn(channels) + self.spatial = SpatialAttn(spatial_kernel_size) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + + +class LightCbamModule(nn.Module): + def __init__(self, channels, spatial_kernel_size=7): + super(LightCbamModule, self).__init__() + self.channel = LightChannelAttn(channels) + self.spatial = LightSpatialAttn(spatial_kernel_size) + + def forward(self, x): + x = self.channel(x) + x = self.spatial(x) + return x + diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py index c8aba217..3bca254f 100644 --- a/timm/models/layers/create_attn.py +++ b/timm/models/layers/create_attn.py @@ -5,6 +5,7 @@ Hacked together by Ross Wightman import torch from .se import SEModule from .eca import EcaModule, CecaModule +from .cbam import CbamModule, LightCbamModule def create_attn(attn_type, channels, **kwargs): @@ -18,6 +19,10 @@ def create_attn(attn_type, channels, **kwargs): module_cls = EcaModule elif attn_type == 'eca': module_cls = CecaModule + elif attn_type == 'cbam': + module_cls = CbamModule + elif attn_type == 'lcbam': + module_cls = LightCbamModule else: assert False, "Invalid attn module (%s)" % attn_type elif isinstance(attn_type, bool):