You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
123 lines
5.0 KiB
123 lines
5.0 KiB
""" PyTorch Conditionally Parameterized Convolution (CondConv)
|
|
|
|
Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference
|
|
(https://arxiv.org/abs/1904.04971)
|
|
|
|
Hacked together by Ross Wightman
|
|
"""
|
|
|
|
import math
|
|
from functools import partial
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
from .helpers import tup_pair
|
|
from .conv2d_same import conv2d_same
|
|
from .padding import get_padding_value
|
|
|
|
|
|
def get_condconv_initializer(initializer, num_experts, expert_shape):
    """Wrap a per-kernel initializer so it can fill a flattened CondConv parameter.

    The returned callable takes a 2-D tensor laid out as
    ``[num_experts, prod(expert_shape)]``. Each row is viewed as a tensor of
    ``expert_shape`` and passed to ``initializer``, so shape-dependent
    initializers see the un-flattened kernel shape rather than a flat row.

    Args:
        initializer: callable applied to each expert's reshaped view. Its return
            value is ignored, so it is expected to modify its argument in-place.
        num_experts: number of rows the parameter must have.
        expert_shape: shape each row is viewed as before initialization.

    Returns:
        A function ``condconv_initializer(weight)`` returning ``None``.
    """
    def condconv_initializer(weight):
        """CondConv initializer function."""
        expected_cols = np.prod(expert_shape)
        shape_ok = (
            len(weight.shape) == 2
            and weight.shape[0] == num_experts
            and weight.shape[1] == expected_cols
        )
        if not shape_ok:
            raise ValueError(
                'CondConv variables must have shape [num_experts, num_params]')
        for expert_idx in range(num_experts):
            # view() shares storage with `weight`, so in-place writes land in the parameter.
            expert_view = weight[expert_idx].view(expert_shape)
            initializer(expert_view)

    return condconv_initializer
|
|
|
|
|
|
class CondConv2d(nn.Module):
    """ Conditionally Parameterized Convolution
    Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

    Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
    https://github.com/pytorch/pytorch/issues/17983

    Stores ``num_experts`` flattened convolution kernels (and, optionally, one
    bias vector per expert). ``forward`` mixes the experts per sample using
    ``routing_weights`` and runs every sample's mixed kernel in a single
    grouped convolution.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size, normalized with ``tup_pair``.
        stride: convolution stride.
        padding: padding spec resolved by ``get_padding_value``; when that
            reports dynamic padding, ``conv2d_same`` is used in ``forward``.
        dilation: convolution dilation.
        groups: number of convolution groups within a single sample.
        bias: if True, learn one bias vector per expert.
        num_experts: number of expert kernels.
    """

    __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = tup_pair(kernel_size)
        self.stride = tup_pair(stride)
        padding_val, is_padding_dynamic = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation)
        # Branched on with a plain `if` in forward() and listed in __constants__
        # (original note: done this way to work with TorchScript).
        self.dynamic_padding = is_padding_dynamic
        self.padding = tup_pair(padding_val)
        self.dilation = tup_pair(dilation)
        self.groups = groups
        self.num_experts = num_experts

        # Shape of a single expert's kernel: (out, in // groups, kh, kw).
        self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight_num_param = int(np.prod(self.weight_shape))
        # One flattened kernel per row; contents are filled by reset_parameters().
        self.weight = torch.nn.Parameter(torch.empty(self.num_experts, weight_num_param))

        if bias:
            self.bias_shape = (self.out_channels,)
            self.bias = torch.nn.Parameter(torch.empty(self.num_experts, self.out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every expert's kernel (and bias, if present) in-place.

        Each expert row is initialized independently through
        ``get_condconv_initializer``: kernels with ``kaiming_uniform_(a=sqrt(5))``,
        biases uniformly in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]`` where ``fan_in``
        is the product of the per-expert kernel shape excluding out_channels.
        NOTE(review): this appears to mirror torch.nn.Conv2d's default init.
        """
        init_weight = get_condconv_initializer(
            partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
        init_weight(self.weight)
        if self.bias is not None:
            fan_in = np.prod(self.weight_shape[1:])
            bound = 1 / math.sqrt(fan_in)
            init_bias = get_condconv_initializer(
                partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
            init_bias(self.bias)

    def forward(self, x, routing_weights):
        """Convolve each sample of ``x`` with its own mixture of expert kernels.

        Args:
            x: 4-D input unpacked as (B, C, H, W); C presumably equals
                ``in_channels``. May be non-contiguous (e.g. channels_last).
            routing_weights: per-sample mixing coefficients; matmul'd with
                ``self.weight`` and must yield B rows — presumably shape
                (B, num_experts).

        Returns:
            Tensor of shape (B, out_channels, H_out, W_out).
        """
        B, C, H, W = x.shape
        # Mix experts per sample, then stack the B mixed kernels along the
        # output-channel dim: sample b owns kernel rows b*out:(b+1)*out.
        weight = torch.matmul(routing_weights, self.weight)
        new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight = weight.view(new_weight_shape)
        bias = None
        if self.bias is not None:
            bias = torch.matmul(routing_weights, self.bias)
            bias = bias.view(B * self.out_channels)

        # Move batch elements into the channel dim so a single grouped conv
        # (groups * B groups) convolves every sample with its own kernel.
        # This is equivalent to looping over samples (the literal TF port) but
        # done in one call. reshape (not view) so non-contiguous inputs such as
        # channels_last tensors are accepted.
        x = x.reshape(1, B * C, H, W)
        if self.dynamic_padding:
            out = conv2d_same(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        else:
            out = F.conv2d(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)

        # (1, B * out_channels, H_out, W_out) -> (B, out_channels, H_out, W_out)
        out = out.reshape(B, self.out_channels, out.shape[-2], out.shape[-1])
        return out
|