From 9f11b4e8a25495874d84a56d4ca11af191a01324 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 29 Jan 2020 13:01:35 -0800
Subject: [PATCH] Add ConvBnAct layer to parallel integrated SelectKernelConv, add support for DropPath and DropBlock to ResNet base and SK blocks

---
 timm/models/conv2d_layers.py |  43 +++++++++----
 timm/models/resnet.py        |  20 +++---
 timm/models/sknet.py         | 118 +++++++++++++++--------
 3 files changed, 96 insertions(+), 85 deletions(-)

diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py
index 7583263a..5b1a44e8 100644
--- a/timm/models/conv2d_layers.py
+++ b/timm/models/conv2d_layers.py
@@ -271,11 +271,36 @@ def _kernel_valid(k):
     assert k >= 3 and k % 2


+class ConvBnAct(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1,
+                 drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+        super(ConvBnAct, self).__init__()
+        padding = _get_padding(kernel_size, stride, dilation)  # assuming PyTorch style padding for this block
+        self.conv = nn.Conv2d(
+            in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=False)
+        self.bn = norm_layer(out_channels)
+        self.drop_block = drop_block
+        if act_layer is not None:
+            self.act = act_layer(inplace=True)
+        else:
+            self.act = None
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        if self.drop_block is not None:
+            x = self.drop_block(x)
+        if self.act is not None:
+            x = self.act(x)
+        return x
+
+
 class SelectiveKernelConv(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
                  attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
-                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+                 drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
         super(SelectiveKernelConv, self).__init__()
         kernel_size = kernel_size or [3, 5]
         _kernel_valid(kernel_size)
@@ -297,19 +322,15 @@ class SelectiveKernelConv(nn.Module):
             out_channels = out_channels // num_paths
         groups = min(out_channels, groups)

-        self.paths = nn.ModuleList()
-        for k, d in zip(kernel_size, dilation):
-            p = _get_padding(k, stride, d)
-            self.paths.append(nn.Sequential(OrderedDict([
-                ('conv', nn.Conv2d(
-                    in_channels, out_channels, kernel_size=k, stride=stride, padding=p,
-                    dilation=d, groups=groups, bias=False)),
-                ('bn', norm_layer(out_channels)),
-                ('act', act_layer(inplace=True))
-            ])))
+        conv_kwargs = dict(
+            stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
+        self.paths = nn.ModuleList([
+            ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
+            for k, d in zip(kernel_size, dilation)])

         attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
         self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels)
+        self.drop_block = drop_block

     def forward(self, x):
         if self.split_input:
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 976d4234..3e0ce23e 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -14,6 +14,7 @@ import torch.nn.functional as F
 from .registry import register_model
 from .helpers import load_pretrained
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .nn_ops import DropBlock2d, DropPath
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


@@ -132,7 +133,8 @@ class BasicBlock(nn.Module):
     expansion = 1

     def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
-                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+                 drop_block=None, drop_path=None):
         super(BasicBlock, self).__init__()

         assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
@@ -181,7 +183,8 @@ class Bottleneck(nn.Module):
     expansion = 4

     def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, use_se=False,
-                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+                 drop_block=None, drop_path=None):
         super(Bottleneck, self).__init__()

         width = int(math.floor(planes * (base_width / 64)) * cardinality)
@@ -305,8 +308,8 @@ class ResNet(nn.Module):
     def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
                  cardinality=1, base_width=64, stem_width=64, stem_type='',
                  block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
-                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
-                 zero_init_last_bn=True, block_args=None):
+                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, drop_path_rate=0.,
+                 drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
         block_args = block_args or dict()
         self.num_classes = num_classes
         deep_stem = 'deep' in stem_type
@@ -338,6 +341,9 @@ class ResNet(nn.Module):
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

         # Feature Blocks
+        dp = DropPath(drop_path_rate) if drop_block_rate else None
+        db_3 = DropBlock2d(drop_block_rate, 7, 0.25) if drop_block_rate else None
+        db_4 = DropBlock2d(drop_block_rate, 7, 1.00) if drop_block_rate else None
         channels, strides, dilations = [64, 128, 256, 512], [1, 2, 2, 2], [1] * 4
         if output_stride == 16:
             strides[3] = 1
@@ -350,11 +356,11 @@ class ResNet(nn.Module):
         llargs = list(zip(channels, layers, strides, dilations))
         lkwargs = dict(
             use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
-            avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
+            avg_down=avg_down, down_kernel_size=down_kernel_size, drop_path=dp, **block_args)
         self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
         self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
-        self.layer3 = self._make_layer(block, *llargs[2], **lkwargs)
-        self.layer4 = self._make_layer(block, *llargs[3], **lkwargs)
+        self.layer3 = self._make_layer(block, drop_block=db_3, *llargs[2], **lkwargs)
+        self.layer4 = self._make_layer(block, drop_block=db_4, *llargs[3], **lkwargs)

         # Head (Pooling and Classifier)
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
diff --git a/timm/models/sknet.py b/timm/models/sknet.py
index e9dbf68d..41e19075 100644
--- a/timm/models/sknet.py
+++ b/timm/models/sknet.py
@@ -4,7 +4,7 @@ from torch import nn as nn

 from timm.models.registry import register_model
 from timm.models.helpers import load_pretrained
-from timm.models.conv2d_layers import SelectiveKernelConv
+from timm.models.conv2d_layers import SelectiveKernelConv, ConvBnAct
 from timm.models.resnet import ResNet, SEModule
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

@@ -29,61 +29,53 @@ default_cfgs = {
 class SelectiveKernelBasic(nn.Module):
     expansion = 1

-    def __init__(self, inplanes, planes, stride=1, downsample=None,
-                 cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
-                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+    def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
+                 use_se=False, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
+                 drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
         super(SelectiveKernelBasic, self).__init__()

         sk_kwargs = sk_kwargs or {}
+        conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
         assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
         assert base_width == 64, 'BasicBlock doest not support changing base width'
         first_planes = planes // reduce_first
-        outplanes = planes * self.expansion
+        out_planes = planes * self.expansion
         first_dilation = first_dilation or dilation

         _selective_first = True  # FIXME temporary, for experiments
         if _selective_first:
             self.conv1 = SelectiveKernelConv(
-                inplanes, first_planes, stride=stride, dilation=first_dilation, **sk_kwargs)
-        else:
-            self.conv1 = nn.Conv2d(
-                inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
-                dilation=first_dilation, bias=False)
-        self.bn1 = norm_layer(first_planes)
-        self.act1 = act_layer(inplace=True)
-        if _selective_first:
-            self.conv2 = nn.Conv2d(
-                first_planes, outplanes, kernel_size=3, padding=dilation,
-                dilation=dilation, bias=False)
+                inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
+            conv_kwargs['act_layer'] = None
+            self.conv2 = ConvBnAct(
+                first_planes, out_planes, kernel_size=3, dilation=dilation, **conv_kwargs)
         else:
+            self.conv1 = ConvBnAct(
+                inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs)
+            conv_kwargs['act_layer'] = None
             self.conv2 = SelectiveKernelConv(
-                first_planes, outplanes, dilation=dilation, **sk_kwargs)
-        self.bn2 = norm_layer(outplanes)
-        self.se = SEModule(outplanes, planes // 4) if use_se else None
-        self.act2 = act_layer(inplace=True)
+                first_planes, out_planes, dilation=dilation, **conv_kwargs, **sk_kwargs)
+        self.se = SEModule(out_planes, planes // 4) if use_se else None
+        self.act = act_layer(inplace=True)
         self.downsample = downsample
         self.stride = stride
         self.dilation = dilation
+        self.drop_block = drop_block
+        self.drop_path = drop_path

     def forward(self, x):
         residual = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.act1(out)
-        out = self.conv2(out)
-        out = self.bn2(out)
-
+        x = self.conv1(x)
+        x = self.conv2(x)
         if self.se is not None:
-            out = self.se(out)
-
+            x = self.se(x)
+        if self.drop_path is not None:
+            x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(x)
-
-        out += residual
-        out = self.act2(out)
-
-        return out
+            residual = self.downsample(residual)
+        x += residual
+        x = self.act(x)
+        return x


 class SelectiveKernelBottleneck(nn.Module):
@@ -91,54 +83,46 @@ class SelectiveKernelBottleneck(nn.Module):

     def __init__(self, inplanes, planes, stride=1, downsample=None,
                  cardinality=1, base_width=64, use_se=False, sk_kwargs=None,
-                 reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+                 reduce_first=1, dilation=1, first_dilation=None,
+                 drop_block=None, drop_path=None,
+                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
         super(SelectiveKernelBottleneck, self).__init__()

         sk_kwargs = sk_kwargs or {}
+        conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
         width = int(math.floor(planes * (base_width / 64)) * cardinality)
         first_planes = width // reduce_first
-        outplanes = planes * self.expansion
+        out_planes = planes * self.expansion
         first_dilation = first_dilation or dilation

-        self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
-        self.bn1 = norm_layer(first_planes)
-        self.act1 = act_layer(inplace=True)
+        self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
         self.conv2 = SelectiveKernelConv(
-            first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, **sk_kwargs)
-        self.bn2 = norm_layer(width)
-        self.act2 = act_layer(inplace=True)
-        self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
-        self.bn3 = norm_layer(outplanes)
-        self.se = SEModule(outplanes, planes // 4) if use_se else None
-        self.act3 = act_layer(inplace=True)
+            first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
+            **conv_kwargs, **sk_kwargs)
+        conv_kwargs['act_layer'] = None
+        self.conv3 = ConvBnAct(width, out_planes, kernel_size=1, **conv_kwargs)
+        self.se = SEModule(out_planes, planes // 4) if use_se else None
+        self.act = act_layer(inplace=True)
         self.downsample = downsample
         self.stride = stride
         self.dilation = dilation
+        self.drop_block = drop_block
+        self.drop_path = drop_path

     def forward(self, x):
         residual = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.act1(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-        out = self.act2(out)
-
-        out = self.conv3(out)
-        out = self.bn3(out)
-
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
         if self.se is not None:
-            out = self.se(out)
-
+            x = self.se(x)
+        if self.drop_path is not None:
+            x = self.drop_path(x)
         if self.downsample is not None:
-            residual = self.downsample(x)
-
-        out += residual
-        out = self.act3(out)
-
-        return out
+            residual = self.downsample(residual)
+        x += residual
+        x = self.act(x)
+        return x


 @register_model
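The standalone sketch below (not part of the patch) illustrates the wiring the diff establishes inside the blocks: a Conv -> BN -> optional DropBlock -> optional activation unit, and DropPath applied to the block branch just before the residual add. The DropPath and ConvBnAct classes here are simplified stand-ins, and BasicResidual is a hypothetical block named only for illustration; none of this is the timm implementation.

# Standalone sketch, assuming only torch is installed.
import torch
import torch.nn as nn


class DropPath(nn.Module):
    # Per-sample stochastic depth: zero the whole branch output with prob drop_prob.
    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1. - self.drop_prob
        # one Bernoulli draw per sample, broadcast over C, H, W
        mask = torch.rand(x.shape[0], 1, 1, 1, device=x.device) < keep_prob
        return x * mask / keep_prob


class ConvBnAct(nn.Module):
    # Conv -> BN -> (optional DropBlock) -> (optional act), mirroring the ordering in the patch.
    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, drop_block=None, act_layer=nn.ReLU):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, kernel_size // 2, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.drop_block = drop_block
        self.act = act_layer(inplace=True) if act_layer is not None else None

    def forward(self, x):
        x = self.bn(self.conv(x))
        if self.drop_block is not None:
            x = self.drop_block(x)
        if self.act is not None:
            x = self.act(x)
        return x


class BasicResidual(nn.Module):
    # Minimal residual block wiring DropPath the same way the SK blocks do.
    def __init__(self, channels, drop_path=None):
        super().__init__()
        self.conv1 = ConvBnAct(channels, channels)
        self.conv2 = ConvBnAct(channels, channels, act_layer=None)  # no act before the add
        self.drop_path = drop_path
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        x = self.conv2(self.conv1(x))
        if self.drop_path is not None:
            x = self.drop_path(x)  # applied to the branch, not the identity path
        x = self.act(x + residual)
        return x


if __name__ == '__main__':
    block = BasicResidual(32, drop_path=DropPath(0.1)).train()
    print(block(torch.randn(2, 32, 16, 16)).shape)  # torch.Size([2, 32, 16, 16])

Leaving act_layer=None on the last ConvBnAct mirrors the conv_kwargs['act_layer'] = None step in the patch, so the non-linearity fires only after the residual has been added.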