From a93bae6dc5a8831f1633208ac45d598a225810c5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 25 Jan 2020 18:31:08 -0800 Subject: [PATCH] A SelectiveKernelBasicBlock for more experiments --- timm/models/resnet.py | 61 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 60 insertions(+), 1 deletion(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 57a20894..1d64dcd9 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -265,7 +265,7 @@ class SelectiveKernelAttn(nn.Module): class SelectiveKernelConv(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=[3, 5], attn_reduction=16, - min_attn_feat=32, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True, + min_attn_feat=16, stride=1, dilation=1, groups=1, keep_3x3=True, use_attn=True, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SelectiveKernelConv, self).__init__() if not isinstance(kernel_size, list): @@ -316,6 +316,53 @@ class SelectiveKernelConv(nn.Module): return x +class SelectiveKernelBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, + cardinality=1, base_width=64, use_se=False, + reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(SelectiveKernelBasicBlock, self).__init__() + + assert cardinality == 1, 'BasicBlock only supports cardinality of 1' + assert base_width == 64, 'BasicBlock doest not support changing base width' + first_planes = planes // reduce_first + outplanes = planes * self.expansion + + self.conv1 = nn.Conv2d( + inplanes, first_planes, kernel_size=3, stride=stride, padding=dilation, + dilation=dilation, bias=False) + self.bn1 = norm_layer(first_planes) + self.act1 = act_layer(inplace=True) + self.conv2 = SelectiveKernelConv(first_planes, outplanes, dilation=previous_dilation) + self.bn2 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None + self.act2 = act_layer(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.act1(out) + out = self.conv2(out) + out = self.bn2(out) + + if self.se is not None: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.act2(out) + + return out + + class SelectiveKernelBottleneck(nn.Module): expansion = 4 @@ -581,6 +628,18 @@ def resnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def skresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-18 model. + """ + default_cfg = default_cfgs['resnet18'] + model = ResNet(SelectiveKernelBasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNet-34 model.