From 1bcc69e0ad8adbd3c31202394415e4bfdbcc62d0 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 9 Feb 2021 16:00:19 -0800
Subject: [PATCH] Use in_channels for depthwise groups, allows using `out_channels=N * in_channels` (does not impact existing models). Fix #354.

---
 timm/models/layers/create_conv2d.py | 3 ++-
 timm/models/layers/mixed_conv2d.py  | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/timm/models/layers/create_conv2d.py b/timm/models/layers/create_conv2d.py
index 0134b05c..3a0cc03a 100644
--- a/timm/models/layers/create_conv2d.py
+++ b/timm/models/layers/create_conv2d.py
@@ -22,7 +22,8 @@ def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
         m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
     else:
         depthwise = kwargs.pop('depthwise', False)
-        groups = out_channels if depthwise else kwargs.pop('groups', 1)
+        # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
+        groups = in_channels if depthwise else kwargs.pop('groups', 1)
         if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
             m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
         else:
diff --git a/timm/models/layers/mixed_conv2d.py b/timm/models/layers/mixed_conv2d.py
index 53d650cd..fa0ce565 100644
--- a/timm/models/layers/mixed_conv2d.py
+++ b/timm/models/layers/mixed_conv2d.py
@@ -34,7 +34,7 @@ class MixedConv2d(nn.ModuleDict):
         self.in_channels = sum(in_splits)
         self.out_channels = sum(out_splits)
         for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
-            conv_groups = out_ch if depthwise else 1
+            conv_groups = in_ch if depthwise else 1
             # use add_module to keep key space clean
             self.add_module(
                 str(idx),