SelectiveKernel split_input works best when the input channels are split across paths as in a grouped conv while the output stays full width. Disable zero_init_last_bn for SK nets; it seems a bad combination.

pull/87/head
Ross Wightman 5 years ago
parent 7d07ebb660
commit 13e8da2b46

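For context, a minimal sketch of the split_input behavior described in the commit message above. This is an illustrative toy module, not the timm implementation: the class name, the fixed 3x3/5x5 paths, and the 1x1-conv attention are assumptions; only the split/stack/weight/sum flow mirrors the change below.

import torch
import torch.nn as nn

class TinySelectiveKernel(nn.Module):
    """Toy SK conv: input channels are split across paths (grouped-conv style),
    but every path emits the full out_channels, so the weighted paths are summed."""
    def __init__(self, in_channels=64, out_channels=64, num_paths=2):
        super().__init__()
        assert in_channels % num_paths == 0
        self.num_paths = num_paths
        self.in_split = in_channels // num_paths
        self.paths = nn.ModuleList([
            nn.Conv2d(self.in_split, out_channels, kernel_size=k, padding=k // 2)
            for k in (3, 5)[:num_paths]])
        # toy attention: pooled features -> one logit per path and channel
        self.attn = nn.Conv2d(out_channels, out_channels * num_paths, kernel_size=1)

    def forward(self, x):
        x_split = torch.split(x, self.in_split, dim=1)
        x_paths = [op(xs) for op, xs in zip(self.paths, x_split)]
        x = torch.stack(x_paths, dim=1)                  # (B, N, C, H, W)
        s = x.sum(dim=1).mean(dim=(2, 3), keepdim=True)  # (B, C, 1, 1)
        a = self.attn(s).reshape(x.shape[0], self.num_paths, -1, 1, 1)
        x = x * torch.softmax(a, dim=1)                  # weight each path
        return x.sum(dim=1)                              # full-width (B, C, H, W)

print(TinySelectiveKernel()(torch.randn(2, 64, 8, 8)).shape)  # torch.Size([2, 64, 8, 8])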
@@ -311,15 +311,13 @@ class SelectiveKernelConv(nn.Module):
             kernel_size = [3] * len(kernel_size)
         else:
             dilation = [dilation] * len(kernel_size)
-        num_paths = len(kernel_size)
-        self.num_paths = num_paths
-        self.split_input = split_input
+        self.num_paths = len(kernel_size)
         self.in_channels = in_channels
         self.out_channels = out_channels
-        if split_input:
-            assert in_channels % num_paths == 0 and out_channels % num_paths == 0
-            in_channels = in_channels // num_paths
-            out_channels = out_channels // num_paths
+        self.split_input = split_input
+        if self.split_input:
+            assert in_channels % self.num_paths == 0
+            in_channels = in_channels // self.num_paths
         groups = min(out_channels, groups)
         conv_kwargs = dict(
@@ -329,7 +327,7 @@ class SelectiveKernelConv(nn.Module):
             for k, d in zip(kernel_size, dilation)])
         attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
-        self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels)
+        self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
         self.drop_block = drop_block

     def forward(self, x):
@@ -338,15 +336,9 @@ class SelectiveKernelConv(nn.Module):
             x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
         else:
             x_paths = [op(x) for op in self.paths]
         x = torch.stack(x_paths, dim=1)
         x_attn = self.attn(x)
         x = x * x_attn
-        if self.split_input:
-            B, N, C, H, W = x.shape
-            x = x.reshape(B, N * C, H, W)
-        else:
-            x = torch.sum(x, dim=1)
+        x = torch.sum(x, dim=1)
         return x
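A quick shape check of the simplified forward above (a standalone sketch; the tensors and sizes are made up for illustration): the stacked per-path outputs are weighted by attention and summed back to a single full-width map, with no reshape needed whether or not the input was split.

import torch

B, N, C, H, W = 2, 2, 64, 8, 8                       # batch, paths, channels, spatial
x_paths = [torch.randn(B, C, H, W) for _ in range(N)]
x = torch.stack(x_paths, dim=1)                       # (B, N, C, H, W)
x_attn = torch.softmax(torch.randn(B, N, C, 1, 1), dim=1)
x = x * x_attn                                        # weight each path
x = torch.sum(x, dim=1)                               # (B, C, H, W), no reshape needed
print(x.shape)                                        # torch.Size([2, 64, 8, 8])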

@@ -158,11 +158,12 @@ def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['skresnet18']
     sk_kwargs = dict(
         min_attn_channels=16,
         attn_reduction=8,
+        split_input=True
     )
     model = ResNet(
         SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
-        block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
+        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -179,7 +180,7 @@ def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     )
     model = ResNet(
         SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
-        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
+        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False,
         **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
@@ -199,7 +200,7 @@ def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['skresnet50']
     model = ResNet(
         SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
-        block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
+        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -218,7 +219,8 @@ def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['skresnet50d']
     model = ResNet(
         SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
-        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), **kwargs)
+        num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs),
+        zero_init_last_bn=False, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -233,7 +235,7 @@ def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['skresnext50_32x4d']
     model = ResNet(
         SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
+        num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
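For reference, a usage sketch of building one of these models through timm's create_model factory (assuming a timm version where the SK ResNets are registered). The zero_init_last_bn=False setting and the sk_kwargs above are baked into the constructors, so nothing extra needs to be passed.

import torch
import timm

# Build an SK-ResNet via the registry; the constructor already sets
# zero_init_last_bn=False and the SelectiveKernel block kwargs.
model = timm.create_model('skresnet18', pretrained=False, num_classes=1000)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])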
