From 13e8da2b46d8b48fa4bdc76dd89cd7aaf3f7d615 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 7 Feb 2020 22:42:04 -0800 Subject: [PATCH] SelectKernel split_input works best when input channels split like grouped conv, but output is full width. Disable zero_init for SK nets, seems a bad combo. --- timm/models/conv2d_layers.py | 22 +++++++--------------- timm/models/sknet.py | 12 +++++++----- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/timm/models/conv2d_layers.py b/timm/models/conv2d_layers.py index 5b1a44e8..feaf653c 100644 --- a/timm/models/conv2d_layers.py +++ b/timm/models/conv2d_layers.py @@ -311,15 +311,13 @@ class SelectiveKernelConv(nn.Module): kernel_size = [3] * len(kernel_size) else: dilation = [dilation] * len(kernel_size) - num_paths = len(kernel_size) - self.num_paths = num_paths - self.split_input = split_input + self.num_paths = len(kernel_size) self.in_channels = in_channels self.out_channels = out_channels - if split_input: - assert in_channels % num_paths == 0 and out_channels % num_paths == 0 - in_channels = in_channels // num_paths - out_channels = out_channels // num_paths + self.split_input = split_input + if self.split_input: + assert in_channels % self.num_paths == 0 + in_channels = in_channels // self.num_paths groups = min(out_channels, groups) conv_kwargs = dict( @@ -329,7 +327,7 @@ class SelectiveKernelConv(nn.Module): for k, d in zip(kernel_size, dilation)]) attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) - self.attn = SelectiveKernelAttn(out_channels, num_paths, attn_channels) + self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) self.drop_block = drop_block def forward(self, x): @@ -338,16 +336,10 @@ class SelectiveKernelConv(nn.Module): x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)] else: x_paths = [op(x) for op in self.paths] - x = torch.stack(x_paths, dim=1) x_attn = self.attn(x) x = x * x_attn - - if self.split_input: - B, N, C, H, W = x.shape - x = x.reshape(B, N * C, H, W) - else: - x = torch.sum(x, dim=1) + x = torch.sum(x, dim=1) return x diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 0c387e39..4b02d501 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -158,11 +158,12 @@ def sksresnet18(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['skresnet18'] sk_kwargs = dict( min_attn_channels=16, + attn_reduction=8, split_input=True ) model = ResNet( SelectiveKernelBasic, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, - block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -179,7 +180,7 @@ def skresnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): ) model = ResNet( SelectiveKernelBottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), + num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False **kwargs) model.default_cfg = default_cfg if pretrained: @@ -199,7 +200,7 @@ def skresnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['skresnet50'] model = ResNet( SelectiveKernelBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, - block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -218,7 +219,8 @@ def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['skresnet50d'] model = ResNet( SelectiveKernelBottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, - num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), **kwargs) + num_classes=num_classes, in_chans=in_chans, block_args=dict(sk_kwargs=sk_kwargs), + zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) @@ -233,7 +235,7 @@ def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): default_cfg = default_cfgs['skresnext50_32x4d'] model = ResNet( SelectiveKernelBottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) + num_classes=num_classes, in_chans=in_chans, zero_init_last_bn=False, **kwargs) model.default_cfg = default_cfg if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans)