You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
121 lines
4.8 KiB
121 lines
4.8 KiB
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import math
|
|
|
|
|
|
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
    """Return True when 'SAME' padding can be a fixed, symmetric pad.

    That is the case when the stride is 1 and the total padding
    ``dilation * (kernel_size - 1)`` is even, so it splits equally between
    both sides. Extra keyword arguments are accepted and ignored so the
    caller can forward a full set of conv kwargs.
    """
    if stride != 1:
        return False
    total_pad = dilation * (kernel_size - 1)
    return total_pad % 2 == 0
|
|
|
|
|
|
def _get_padding(kernel_size, stride=1, dilation=1, **_):
    """Compute the symmetric per-side padding for a convolution.

    The result is ``((stride - 1) + dilation * (kernel_size - 1)) // 2``.
    Extra keyword arguments are accepted and ignored so the caller can
    forward a full set of conv kwargs.
    """
    dilated_extent = dilation * (kernel_size - 1)
    return (dilated_extent + (stride - 1)) // 2
|
|
|
|
|
|
def _calc_same_pad(i, k, s, d):
    """Total padding along one dimension for TF-style 'SAME' output size.

    Args:
        i: input size along the dimension.
        k: kernel size along the dimension.
        s: stride along the dimension.
        d: dilation along the dimension.

    Returns:
        The non-negative amount of padding (both sides combined) needed so
        that the output size is ``ceil(i / s)``.
    """
    out_size = math.ceil(i / s)
    # Input extent required to produce `out_size` outputs.
    required = (out_size - 1) * s + (k - 1) * d + 1
    return max(required - i, 0)
|
|
|
|
|
|
def _split_channels(num_chan, num_groups):
    """Divide ``num_chan`` channels into ``num_groups`` group sizes.

    Every group gets ``num_chan // num_groups`` channels and the first group
    additionally receives whatever is left over, so the sizes always sum to
    ``num_chan``.
    """
    even_share = [num_chan // num_groups for _ in range(num_groups)]
    leftover = num_chan - sum(even_share)
    return [even_share[0] + leftover] + even_share[1:]
|
|
|
|
|
|
class Conv2dSame(nn.Conv2d):
    """ 2D convolution emulating Tensorflow's 'SAME' padding.

    The padding amount is derived from the input's spatial size on every
    forward pass and applied with ``F.pad`` before the convolution. When the
    required padding is odd, the extra element goes on the right / bottom.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        # `padding` is accepted only to mirror nn.Conv2d's signature; it is
        # ignored and the parent is always configured with zero padding.
        super(Conv2dSame, self).__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=0, dilation=dilation,
            groups=groups, bias=bias)

    def forward(self, x):
        in_h, in_w = x.size()[-2:]
        kernel_h, kernel_w = self.weight.size()[-2:]
        stride_h, stride_w = self.stride[0], self.stride[1]
        dilation_h, dilation_w = self.dilation[0], self.dilation[1]

        total_h = _calc_same_pad(in_h, kernel_h, stride_h, dilation_h)
        total_w = _calc_same_pad(in_w, kernel_w, stride_w, dilation_w)

        if total_h > 0 or total_w > 0:
            left, top = total_w // 2, total_h // 2
            # F.pad takes pads for the last dimension first: (left, right, top, bottom).
            x = F.pad(x, [left, total_w - left, top, total_h - top])

        # self.padding is the zero padding set in __init__.
        return F.conv2d(
            x, self.weight, self.bias, self.stride,
            self.padding, self.dilation, self.groups)
|
|
|
|
|
|
def conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    """Create a 2D convolution, resolving ``padding`` from a string or number.

    ``padding`` is taken from ``kwargs`` (default ``''``) and ``bias``
    defaults to False. Remaining kwargs are forwarded to the conv layer.

    Padding handling:
        * non-string value: passed to ``nn.Conv2d`` unchanged.
        * ``'valid'`` (any case): ``nn.Conv2d`` with ``padding=0``.
        * ``'same'`` (any case): TF compatible 'SAME' padding. When the padding
          is static (see ``_is_static_pad``) a plain ``nn.Conv2d`` with
          symmetric padding is used; otherwise a ``Conv2dSame`` is returned,
          which pads dynamically and so has a performance and GPU memory
          allocation impact.
        * any other string: PyTorch style 'same'-ish symmetric padding from
          ``_get_padding``.
    """
    padding = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)

    if not isinstance(padding, str):
        # padding was specified as a number or pair
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)

    mode = padding.lower()

    if mode == 'valid':
        # 'VALID' padding, same as padding=0
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs)

    if mode == 'same' and not _is_static_pad(kernel_size, **kwargs):
        # dynamic padding, computed per input in Conv2dSame.forward
        return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)

    # Static 'same' and the default mode both use symmetric padding.
    symmetric = _get_padding(kernel_size, **kwargs)
    return nn.Conv2d(in_chs, out_chs, kernel_size, padding=symmetric, **kwargs)
|
|
|
|
|
|
class MixedConv2d(nn.Module):
    """ Mixed Grouped Convolution

    Splits the input channels into one group per entry of ``kernel_size``,
    runs a separate convolution on each group and concatenates the results
    along the channel dimension (dim 1).

    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py

    Args:
        in_channels: total input channels, divided across groups by
            ``_split_channels``.
        out_channels: total output channels, divided the same way.
        kernel_size: a list with one kernel size per group; any non-list
            value is treated as a single group.
        stride: stride passed to every group's convolution.
        padding: padding spec passed to ``conv2d_pad``.
        dilated: if True and ``stride == 1``, kernels larger than 3 are
            replaced by a 3x3 kernel with dilation ``(k - 1) // 2``.
        depthwise: if True each group's conv uses ``groups=out_ch``.
        **kwargs: forwarded to ``conv2d_pad``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilated=False, depthwise=False, **kwargs):
        super(MixedConv2d, self).__init__()

        kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
        num_groups = len(kernel_size)
        in_splits = _split_channels(in_channels, num_groups)
        out_splits = _split_channels(out_channels, num_groups)
        for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
            k, d = self._kernel_and_dilation(k, stride, dilated)
            conv_groups = out_ch if depthwise else 1
            # use add_module to keep key space clean
            self.add_module(
                str(idx),
                conv2d_pad(
                    in_ch, out_ch, k, stride=stride,
                    padding=padding, dilation=d, groups=conv_groups, **kwargs)
            )
        self.splits = in_splits

    @staticmethod
    def _kernel_and_dilation(k, stride, dilated):
        """Return the ``(kernel_size, dilation)`` to use for one group."""
        # FIXME make compat with non-square kernel/dilations/strides
        # Only kernels larger than 3 are rewritten: for k < 3 the formula
        # (k - 1) // 2 yields a dilation of 0, which is not a valid
        # convolution dilation, and for k == 3 the rewrite is a no-op.
        if dilated and stride == 1 and k > 3:
            return 3, (k - 1) // 2
        return k, 1

    def forward(self, x):
        chunks = torch.split(x, self.splits, 1)
        # Sub-convs were registered in group order, matching the split order.
        outputs = [conv(chunk) for chunk, conv in zip(chunks, self._modules.values())]
        return torch.cat(outputs, 1)
|
|
|
|
|
|
# helper method
|
|
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
    """Build either a plain padded conv or a MixedConv2d from ``kernel_size``.

    A list ``kernel_size`` selects ``MixedConv2d`` (one kernel per group).
    Anything else (ints, tuples, other iterables) goes to ``conv2d_pad`` as a
    regular kernel size, with the ``depthwise`` flag translated to
    ``groups=out_chs``.

    ``groups`` must not be passed directly; use the ``depthwise`` bool arg.
    """
    assert 'groups' not in kwargs  # only use 'depthwise' bool arg

    if not isinstance(kernel_size, list):
        is_depthwise = kwargs.pop('depthwise', False)
        num_groups = out_chs if is_depthwise else 1
        return conv2d_pad(in_chs, out_chs, kernel_size, groups=num_groups, **kwargs)

    # Only lists define MixedConv2d kernel groups.
    return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
|
|
|