Fix #139. Broken SKResNets after BlurPool addition, as a plus, SKResNets support AA now too.

pull/140/head
Ross Wightman 4 years ago
parent 353a79aeba
commit 8d8677e03b

@ -9,13 +9,15 @@ from timm.models.layers import get_padding
class ConvBnAct(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, dilation=1, groups=1,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
super(ConvBnAct, self).__init__()
padding = get_padding(kernel_size, stride, dilation) # assuming PyTorch style padding for this block
use_aa = aa_layer is not None
self.conv = nn.Conv2d(
in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=1 if use_aa else stride,
padding=padding, dilation=dilation, groups=groups, bias=False)
self.bn = norm_layer(out_channels)
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None
self.drop_block = drop_block
if act_layer is not None:
self.act = act_layer(inplace=True)
@ -29,4 +31,6 @@ class ConvBnAct(nn.Module):
x = self.drop_block(x)
if self.act is not None:
x = self.act(x)
if self.aa is not None:
x = self.aa(x)
return x

@ -52,7 +52,7 @@ class SelectiveKernelConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
""" Selective Kernel Convolution Module
As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
@ -98,7 +98,8 @@ class SelectiveKernelConv(nn.Module):
groups = min(out_channels, groups)
conv_kwargs = dict(
stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer,
aa_layer=aa_layer)
self.paths = nn.ModuleList([
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
for k, d in zip(kernel_size, dilation)])

@ -46,12 +46,12 @@ class SelectiveKernelBasic(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(SelectiveKernelBasic, self).__init__()
sk_kwargs = sk_kwargs or {}
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
assert base_width == 64, 'BasicBlock doest not support changing base width'
first_planes = planes // reduce_first
@ -94,11 +94,12 @@ class SelectiveKernelBottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
drop_block=None, drop_path=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None):
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
drop_block=None, drop_path=None):
super(SelectiveKernelBottleneck, self).__init__()
sk_kwargs = sk_kwargs or {}
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer)
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)
width = int(math.floor(planes * (base_width / 64)) * cardinality)
first_planes = width // reduce_first
outplanes = planes * self.expansion

@ -1 +1 @@
__version__ = '0.1.24'
__version__ = '0.1.26'

Loading…
Cancel
Save