pull/53/head
parent
db04677c94
commit
506df0e3d0
@ -1,120 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
    """Return whether 'same' padding can be expressed as a fixed symmetric pad.

    This holds only for ``stride == 1`` when ``dilation * (kernel_size - 1)``
    is even, i.e. the total padding splits evenly between both sides.
    Extra keyword arguments are accepted and ignored.
    """
    if stride != 1:
        return False
    total_pad = dilation * (kernel_size - 1)
    return total_pad % 2 == 0
|
||||
|
||||
|
||||
def _get_padding(kernel_size, stride=1, dilation=1, **_):
    """Return the symmetric per-side padding for the given conv geometry.

    Computed as ``((stride - 1) + dilation * (kernel_size - 1)) // 2``.
    Extra keyword arguments are accepted and ignored.
    """
    dilated_extent = dilation * (kernel_size - 1)
    return (dilated_extent + (stride - 1)) // 2
|
||||
|
||||
|
||||
def _calc_same_pad(i, k, s, d):
    """Return the total padding for one spatial dimension (never negative).

    Args:
        i: input size along the dimension.
        k: kernel size.
        s: stride.
        d: dilation.

    The result is the padding needed so that ``ceil(i / s)`` output positions
    fit, clamped at zero.
    """
    out_positions = math.ceil(i / s)
    required_extent = (out_positions - 1) * s + (k - 1) * d + 1
    return max(required_extent - i, 0)
|
||||
|
||||
|
||||
def _split_channels(num_chan, num_groups):
    """Split ``num_chan`` into ``num_groups`` integer chunk sizes.

    Every group gets ``num_chan // num_groups`` channels; whatever is left
    over is added to the first group so the sizes sum to ``num_chan``.
    """
    sizes = [num_chan // num_groups for _ in range(num_groups)]
    leftover = num_chan - sum(sizes)
    sizes[0] += leftover
    return sizes
|
||||
|
||||
|
||||
class Conv2dSame(nn.Conv2d):
    """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions

    The ``padding`` constructor argument is accepted but ignored; the parent
    ``nn.Conv2d`` is always built with padding 0 and the input is padded in
    ``forward`` based on its actual spatial size.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=0, dilation=dilation, groups=groups, bias=bias)

    def forward(self, x):
        in_h, in_w = x.size()[-2:]
        k_h, k_w = self.weight.size()[-2:]
        total_h = _calc_same_pad(in_h, k_h, self.stride[0], self.dilation[0])
        total_w = _calc_same_pad(in_w, k_w, self.stride[1], self.dilation[1])
        if total_h > 0 or total_w > 0:
            # When the total is odd, the extra element goes to the right / bottom.
            left, top = total_w // 2, total_h // 2
            x = F.pad(x, [left, total_w - left, top, total_h - top])
        return F.conv2d(
            x, self.weight, self.bias, self.stride,
            self.padding, self.dilation, self.groups)
|
||||
|
||||
|
||||
def conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    """Create a 2D convolution, resolving string ``padding`` modes.

    ``padding`` is taken from ``kwargs`` (default ``''``) and ``bias`` defaults
    to ``False``. String padding is handled case-insensitively:

    * ``'same'``: TF compatible 'SAME' padding. Uses a plain ``nn.Conv2d`` when
      the padding is static, otherwise a ``Conv2dSame`` (dynamic padding, which
      has a performance and GPU memory allocation impact).
    * ``'valid'``: same as ``padding=0``.
    * any other string: PyTorch style 'same'-ish symmetric padding.

    A non-string ``padding`` (number or pair) is passed to ``nn.Conv2d`` as is.
    """
    padding = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    if isinstance(padding, str):
        mode = padding.lower()
        if mode == 'valid':
            padding = 0
        elif mode == 'same' and not _is_static_pad(kernel_size, **kwargs):
            # dynamic padding is the only case that is not a plain nn.Conv2d
            return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
        else:
            # static 'same', or the default symmetric padding
            padding = _get_padding(kernel_size, **kwargs)
    return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
|
||||
|
||||
|
||||
class MixedConv2d(nn.Module):
    """ Mixed Grouped Convolution
    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py

    The input channels are split into one chunk per entry of ``kernel_size``.
    Each chunk is run through its own convolution and the outputs are
    concatenated along the channel dimension.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilated=False, depthwise=False, **kwargs):
        super().__init__()

        kernel_sizes = kernel_size if isinstance(kernel_size, list) else [kernel_size]
        branch_count = len(kernel_sizes)
        in_splits = _split_channels(in_channels, branch_count)
        out_splits = _split_channels(out_channels, branch_count)
        for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_sizes, in_splits, out_splits)):
            # FIXME make compat with non-square kernel/dilations/strides
            if stride == 1 and dilated:
                # swap the large kernel for a dilated 3x3
                branch_dilation, branch_kernel = (k - 1) // 2, 3
            else:
                branch_dilation, branch_kernel = 1, k
            branch = conv2d_pad(
                in_ch, out_ch, branch_kernel, stride=stride, padding=padding,
                dilation=branch_dilation, groups=out_ch if depthwise else 1, **kwargs)
            # use add_module to keep key space clean
            self.add_module(str(idx), branch)
        self.splits = in_splits

    def forward(self, x):
        chunks = torch.split(x, self.splits, 1)
        outputs = [conv(chunk) for chunk, conv in zip(chunks, self._modules.values())]
        return torch.cat(outputs, 1)
|
||||
|
||||
|
||||
# helper method
|
||||
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
    """Build a ``MixedConv2d`` for a list ``kernel_size``, else a padded conv.

    ``groups`` must not be passed; use the ``depthwise`` bool instead.
    """
    assert 'groups' not in kwargs  # only use 'depthwise' bool arg
    if not isinstance(kernel_size, list):
        is_depthwise = kwargs.pop('depthwise', False)
        group_count = out_chs if is_depthwise else 1
        return conv2d_pad(in_chs, out_chs, kernel_size, groups=group_count, **kwargs)
    # We're going to use only lists for defining the MixedConv2d kernel groups,
    # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
    return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
|
||||
|
@ -0,0 +1,255 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch._six import container_abcs
|
||||
from itertools import repeat
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
|
||||
def _ntuple(n):
    """Return a parser that normalises a scalar or iterable into a tuple.

    The returned ``parse(x)`` repeats a non-iterable ``x`` ``n`` times. An
    iterable ``x`` is converted to a tuple with its own length, so callers can
    safely concatenate the result with other tuples even when given a list.
    """
    # Local stdlib import: avoids depending on the private torch._six shim.
    from collections.abc import Iterable

    def parse(x):
        if isinstance(x, Iterable):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
|
||||
|
||||
|
||||
def _is_static_pad(kernel_size, stride=1, dilation=1, **_):
    """Return whether 'same' padding can be expressed as a fixed symmetric pad.

    This holds only for ``stride == 1`` when ``dilation * (kernel_size - 1)``
    is even, i.e. the total padding splits evenly between both sides.
    Extra keyword arguments are accepted and ignored.
    """
    if stride != 1:
        return False
    total_pad = dilation * (kernel_size - 1)
    return total_pad % 2 == 0
|
||||
|
||||
|
||||
def _get_padding(kernel_size, stride=1, dilation=1, **_):
    """Return the symmetric per-side padding for the given conv geometry.

    Computed as ``((stride - 1) + dilation * (kernel_size - 1)) // 2``.
    Extra keyword arguments are accepted and ignored.
    """
    dilated_extent = dilation * (kernel_size - 1)
    return (dilated_extent + (stride - 1)) // 2
|
||||
|
||||
|
||||
def _calc_same_pad(i, k, s, d):
    """Return the total padding for one spatial dimension (never negative).

    Args:
        i: input size along the dimension.
        k: kernel size.
        s: stride.
        d: dilation.

    The result is the padding needed so that ``ceil(i / s)`` output positions
    fit, clamped at zero.
    """
    out_positions = math.ceil(i / s)
    required_extent = (out_positions - 1) * s + (k - 1) * d + 1
    return max(required_extent - i, 0)
|
||||
|
||||
|
||||
def _split_channels(num_chan, num_groups):
    """Split ``num_chan`` into ``num_groups`` integer chunk sizes.

    Every group gets ``num_chan // num_groups`` channels; whatever is left
    over is added to the first group so the sizes sum to ``num_chan``.
    """
    sizes = [num_chan // num_groups for _ in range(num_groups)]
    leftover = num_chan - sum(sizes)
    sizes[0] += leftover
    return sizes
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def conv2d_same(x, weight, bias=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1):
    """Functional 2D convolution with dynamically computed 'SAME' padding.

    The ``padding`` argument is accepted for signature compatibility but is
    not used: the input is padded from its last two dimensions and the
    convolution itself always runs with padding ``(0, 0)``.
    """
    in_h, in_w = x.size()[-2:]
    k_h, k_w = weight.size()[-2:]
    total_h = _calc_same_pad(in_h, k_h, stride[0], dilation[0])
    total_w = _calc_same_pad(in_w, k_w, stride[1], dilation[1])
    if total_h > 0 or total_w > 0:
        # When the total is odd, the extra element goes to the right / bottom.
        left, top = total_w // 2, total_h // 2
        x = F.pad(x, [left, total_w - left, top, total_h - top])
    return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
|
||||
|
||||
|
||||
class Conv2dSame(nn.Conv2d):
    """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions

    The ``padding`` constructor argument is accepted but ignored; the parent
    ``nn.Conv2d`` is always built with padding 0 and ``forward`` delegates to
    ``conv2d_same``.
    """

    # pylint: disable=unused-argument
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=0, dilation=dilation, groups=groups, bias=bias)

    def forward(self, x):
        return conv2d_same(
            x, self.weight, self.bias,
            stride=self.stride, padding=self.padding,
            dilation=self.dilation, groups=self.groups)
|
||||
|
||||
|
||||
def get_padding_value(padding, kernel_size, **kwargs):
    """Resolve a padding spec into ``(padding, dynamic)``.

    For any string padding, the padding is calculated one of three ways
    (case-insensitive):

    * ``'same'``: TF compatible 'SAME' padding. Static when possible (no extra
      overhead); otherwise ``(0, True)`` to request dynamic padding, which has
      a performance and GPU memory allocation impact.
    * ``'valid'``: same as ``padding=0``.
    * any other string: PyTorch style 'same'-ish symmetric padding.

    A non-string ``padding`` is returned unchanged with ``dynamic=False``.
    """
    if not isinstance(padding, str):
        return padding, False
    mode = padding.lower()
    if mode == 'valid':
        return 0, False
    if mode == 'same' and not _is_static_pad(kernel_size, **kwargs):
        # dynamic padding
        return 0, True
    # static 'same', or the default symmetric padding
    return _get_padding(kernel_size, **kwargs), False
|
||||
|
||||
|
||||
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    """Create a conv layer from ``kwargs``, resolving string padding modes.

    ``padding`` is taken from ``kwargs`` (default ``''``) and ``bias`` defaults
    to ``False``. Returns a ``Conv2dSame`` when the resolved padding is
    dynamic, otherwise a plain ``nn.Conv2d`` with the resolved padding.
    """
    requested = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    resolved, needs_dynamic = get_padding_value(requested, kernel_size, **kwargs)
    if not needs_dynamic:
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=resolved, **kwargs)
    return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
|
||||
|
||||
|
||||
class MixedConv2d(nn.Module):
    """ Mixed Grouped Convolution
    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py

    The input channels are split into one chunk per entry of ``kernel_size``.
    Each chunk is run through its own convolution and the outputs are
    concatenated along the channel dimension.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, mixed_dilated=False, depthwise=False, **kwargs):
        super().__init__()

        kernel_sizes = kernel_size if isinstance(kernel_size, list) else [kernel_size]
        branch_count = len(kernel_sizes)
        in_splits = _split_channels(in_channels, branch_count)
        out_splits = _split_channels(out_channels, branch_count)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)
        for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_sizes, in_splits, out_splits)):
            # FIXME make compat with non-square kernel/dilations/strides
            if stride == 1 and mixed_dilated:
                # swap the large kernel for a dilated 3x3
                branch_dilation, branch_kernel = (k - 1) // 2, 3
            else:
                branch_dilation, branch_kernel = dilation, k
            branch = create_conv2d_pad(
                in_ch, out_ch, branch_kernel, stride=stride, padding=padding,
                dilation=branch_dilation, groups=out_ch if depthwise else 1, **kwargs)
            # use add_module to keep key space clean
            self.add_module(str(idx), branch)
        self.splits = in_splits

    def forward(self, x):
        chunks = torch.split(x, self.splits, 1)
        outputs = [conv(chunk) for chunk, conv in zip(chunks, self._modules.values())]
        return torch.cat(outputs, 1)
|
||||
|
||||
|
||||
def get_condconv_initializer(initializer, num_experts, expert_shape):
    """Wrap ``initializer`` so it is applied to each expert of a 2D weight.

    The returned function expects a weight of shape
    ``[num_experts, prod(expert_shape)]`` and calls ``initializer`` on every
    row viewed as ``expert_shape``.

    Raises:
        ValueError: (from the returned function) if the weight does not have
            the expected shape.
    """
    def condconv_initializer(weight):
        """CondConv initializer function."""
        num_params = np.prod(expert_shape)
        shape_ok = (
            len(weight.shape) == 2
            and weight.shape[0] == num_experts
            and weight.shape[1] == num_params
        )
        if not shape_ok:
            raise ValueError(
                'CondConv variables must have shape [num_experts, num_params]')
        for expert_idx in range(num_experts):
            expert_weight = weight[expert_idx].view(expert_shape)
            initializer(expert_weight)
    return condconv_initializer
|
||||
|
||||
|
||||
class CondConv2d(nn.Module):
    """ Conditional Convolution
    Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

    Stores ``num_experts`` flattened kernels (and optionally biases). In
    ``forward`` they are mixed per batch element by ``routing_weights`` and the
    resulting kernel is applied to that element.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
        super(CondConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        padding_val, is_padding_dynamic = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation)
        # Dynamic 'same' padding needs the padding-aware functional conv.
        self.conv_fn = conv2d_same if is_padding_dynamic else F.conv2d
        self.padding = _pair(padding_val)
        self.dilation = _pair(dilation)
        self.transposed = False
        self.output_padding = _pair(0)
        self.groups = groups
        self.padding_mode = 'zero'
        self.num_experts = num_experts

        # Shape of a single expert kernel; experts are stored flattened, one per row.
        self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight_num_param = 1
        for wd in self.weight_shape:
            weight_num_param *= wd
        self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))

        if bias:
            self.bias_shape = (self.out_channels,)
            condconv_bias_shape = (self.num_experts, self.out_channels)
            self.bias = torch.nn.Parameter(torch.Tensor(condconv_bias_shape))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
        # FIXME once I'm satisfied this works, remove the looping path?
        self._use_groups = True  # use groups for parallel per-batch-element kernel convolution

    def reset_parameters(self):
        """Initialise every expert kernel (and bias, if present) independently."""
        init_weight = get_condconv_initializer(
            partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
        init_weight(self.weight)
        if self.bias is not None:
            fan_in = np.prod(self.weight_shape[1:])
            bound = 1 / math.sqrt(fan_in)
            init_bias = get_condconv_initializer(
                partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
            init_bias(self.bias)

    def forward(self, x, routing_weights):
        """Convolve each batch element with its own expert mixture.

        Args:
            x: input of shape ``(B, C, H, W)`` (unpacked as such below).
            routing_weights: matrix multiplied with the ``(num_experts, ...)``
                parameters; one row of mixing coefficients per batch element.

        Returns:
            Tensor of shape ``(B, out_channels, H_out, W_out)``.
        """
        B, C, H, W = x.shape
        # Mix the experts: one flattened kernel (and bias) row per batch element.
        weight = torch.matmul(routing_weights, self.weight)
        bias = torch.matmul(routing_weights, self.bias) if self.bias is not None else None
        if self._use_groups:
            # Fold the batch into the channel dim and run one grouped conv, so
            # every batch element is convolved with its own kernel.
            new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
            weight = weight.view(new_weight_shape)
            if bias is not None:
                # The folded conv has B * out_channels outputs and needs a 1D bias of that size.
                bias = bias.view(B * self.out_channels)
            # reshape (not view) so non-contiguous inputs are accepted
            x = x.reshape(1, B * C, H, W)
            out = self.conv_fn(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
            out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
        else:
            # Reference path: convolve the batch elements one at a time.
            x = torch.split(x, 1, 0)
            weight = torch.split(weight, 1, 0)
            bias = torch.split(bias, 1, 0) if bias is not None else [None] * B
            out = []
            for xi, wi, bi in zip(x, weight, bias):
                wi = wi.view(*self.weight_shape)
                if bi is not None:
                    bi = bi.view(*self.bias_shape)
                out.append(self.conv_fn(
                    xi, wi, bi, stride=self.stride, padding=self.padding,
                    dilation=self.dilation, groups=self.groups))
            out = torch.cat(out, 0)
        return out
|
||||
|
||||
|
||||
# helper method
|
||||
def select_conv2d(in_chs, out_chs, kernel_size, **kwargs):
    """Select and build the conv layer matching the arguments.

    * list ``kernel_size`` -> ``MixedConv2d`` (CondConv is not supported there)
    * ``num_experts > 0``  -> ``CondConv2d``
    * otherwise            -> ``create_conv2d_pad``

    ``groups`` must not be passed; use the ``depthwise`` bool instead.
    ``num_experts`` is always removed from ``kwargs`` before dispatch so that a
    zero / ``None`` value is not forwarded to layers that do not accept it.
    """
    assert 'groups' not in kwargs  # only use 'depthwise' bool arg
    num_experts = kwargs.pop('num_experts', 0) or 0
    if isinstance(kernel_size, list):
        assert num_experts == 0  # MixNet + CondConv combo not supported currently
        # We're going to use only lists for defining the MixedConv2d kernel groups,
        # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
        return MixedConv2d(in_chs, out_chs, kernel_size, **kwargs)
    depthwise = kwargs.pop('depthwise', False)
    groups = out_chs if depthwise else 1
    if num_experts > 0:
        return CondConv2d(
            in_chs, out_chs, kernel_size, groups=groups, num_experts=num_experts, **kwargs)
    return create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs)
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue