Use in_channels for depthwise groups, which allows `out_channels = N * in_channels` (does not affect existing models). Fixes #354.

pull/419/head
Ross Wightman 4 years ago
parent 9811e229f7
commit 1bcc69e0ad

@ -22,7 +22,8 @@ def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
else: else:
depthwise = kwargs.pop('depthwise', False) depthwise = kwargs.pop('depthwise', False)
groups = out_channels if depthwise else kwargs.pop('groups', 1) # for depthwise (DW) convs, out_channels must be a multiple of in_channels, because out_channels % groups == 0 is required
groups = in_channels if depthwise else kwargs.pop('groups', 1)
if 'num_experts' in kwargs and kwargs['num_experts'] > 0: if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
else: else:

@ -34,7 +34,7 @@ class MixedConv2d(nn.ModuleDict):
self.in_channels = sum(in_splits) self.in_channels = sum(in_splits)
self.out_channels = sum(out_splits) self.out_channels = sum(out_splits)
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
conv_groups = out_ch if depthwise else 1 conv_groups = in_ch if depthwise else 1
# use add_module to keep key space clean # use add_module to keep key space clean
self.add_module( self.add_module(
str(idx), str(idx),

Loading…
Cancel
Save