""" PyTorch Mixed Convolution Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) Hacked together by / Copyright 2020 Ross Wightman """ import torch from torch import nn as nn from .conv2d_same import create_conv2d_pad def _split_channels(num_chan, num_groups): split = [num_chan // num_groups for _ in range(num_groups)] split[0] += num_chan - sum(split) return split class MixedConv2d(nn.ModuleDict): """ Mixed Grouped Convolution Based on MDConv and GroupedConv in MixNet impl: https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='', dilation=1, depthwise=False, **kwargs): super(MixedConv2d, self).__init__() kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] num_groups = len(kernel_size) in_splits = _split_channels(in_channels, num_groups) out_splits = _split_channels(out_channels, num_groups) self.in_channels = sum(in_splits) self.out_channels = sum(out_splits) for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): conv_groups = out_ch if depthwise else 1 # use add_module to keep key space clean self.add_module( str(idx), create_conv2d_pad( in_ch, out_ch, k, stride=stride, padding=padding, dilation=dilation, groups=conv_groups, **kwargs) ) self.splits = in_splits def forward(self, x): x_split = torch.split(x, self.splits, 1) x_out = [c(x_split[i]) for i, c in enumerate(self.values())] x = torch.cat(x_out, 1) return x