You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
52 lines
1.8 KiB
52 lines
1.8 KiB
5 years ago
|
""" PyTorch Mixed Convolution
|
||
|
|
||
|
Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
|
||
5 years ago
|
|
||
5 years ago
|
Hacked together by / Copyright 2020 Ross Wightman
|
||
5 years ago
|
"""
|
||
|
|
||
|
import torch
|
||
|
from torch import nn as nn
|
||
|
|
||
|
from .conv2d_same import create_conv2d_pad
|
||
|
|
||
|
|
||
|
def _split_channels(num_chan, num_groups):
|
||
|
split = [num_chan // num_groups for _ in range(num_groups)]
|
||
|
split[0] += num_chan - sum(split)
|
||
|
return split
|
||
|
|
||
|
|
||
|
class MixedConv2d(nn.ModuleDict):
    """Mixed grouped convolution (MixConv).

    Channels are partitioned into one group per kernel size. Each group gets
    its own conv layer built by ``create_conv2d_pad`` (registered under the
    keys ``"0"``, ``"1"``, ...). ``forward`` runs every channel slice through
    its conv and concatenates the results along dim 1.

    Based on MDConv and GroupedConv in the MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py

    Args:
        in_channels: total input channels, divided across the groups with
            ``_split_channels``.
        out_channels: total output channels, divided the same way.
        kernel_size: one kernel size, or a ``list`` of kernel sizes (one group
            per entry). Only an actual ``list`` is treated as multiple sizes;
            any other value, including a tuple, becomes a single group.
        stride: forwarded to every group's conv.
        padding: forwarded to every group's conv.
        dilation: forwarded to every group's conv.
        depthwise: when truthy, each group's conv uses ``groups`` equal to that
            group's input channel count; otherwise ``groups=1``.
        **kwargs: extra keyword arguments forwarded to ``create_conv2d_pad``.

    Attributes:
        in_channels: sum of the per-group input channel counts.
        out_channels: sum of the per-group output channel counts.
        splits: per-group input channel counts, used by ``forward`` to slice
            the input.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, depthwise=False, **kwargs):
        super().__init__()

        kernels = kernel_size if isinstance(kernel_size, list) else [kernel_size]
        group_count = len(kernels)
        in_splits = _split_channels(in_channels, group_count)
        out_splits = _split_channels(out_channels, group_count)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)

        for idx, (k, in_ch, out_ch) in enumerate(zip(kernels, in_splits, out_splits)):
            conv = create_conv2d_pad(
                in_ch, out_ch, k,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=in_ch if depthwise else 1,
                **kwargs)
            # add_module with plain stringified indices keeps the key space clean
            self.add_module(str(idx), conv)
        self.splits = in_splits

    def forward(self, x):
        """Apply each group's conv to its channel slice and concatenate.

        ``x`` is split along dim 1 into chunks sized by ``self.splits``; the
        i-th registered conv consumes the i-th chunk, and the outputs are
        joined back along dim 1.
        """
        chunks = torch.split(x, self.splits, 1)
        return torch.cat(
            [conv(chunks[i]) for i, conv in enumerate(self.values())], 1)
|