""" PyTorch Conditionally Parameterized Convolution (CondConv) Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference (https://arxiv.org/abs/1904.04971) Hacked together by / Copyright 2020 Ross Wightman """ import math from functools import partial import numpy as np import torch from torch import nn as nn from torch.nn import functional as F from .helpers import to_2tuple from .conv2d_same import conv2d_same from .padding import get_padding_value def get_condconv_initializer(initializer, num_experts, expert_shape): def condconv_initializer(weight): """CondConv initializer function.""" num_params = np.prod(expert_shape) if (len(weight.shape) != 2 or weight.shape[0] != num_experts or weight.shape[1] != num_params): raise (ValueError( 'CondConv variables must have shape [num_experts, num_params]')) for i in range(num_experts): initializer(weight[i].view(expert_shape)) return condconv_initializer class CondConv2d(nn.Module): """ Conditionally Parameterized Convolution Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: https://github.com/pytorch/pytorch/issues/17983 """ __constants__ = ['in_channels', 'out_channels', 'dynamic_padding'] def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): super(CondConv2d, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = to_2tuple(kernel_size) self.stride = to_2tuple(stride) padding_val, is_padding_dynamic = get_padding_value( padding, kernel_size, stride=stride, dilation=dilation) self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript self.padding = to_2tuple(padding_val) self.dilation = to_2tuple(dilation) self.groups = groups self.num_experts = num_experts self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size weight_num_param = 1 for wd in self.weight_shape: weight_num_param *= wd self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) if bias: self.bias_shape = (self.out_channels,) self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): init_weight = get_condconv_initializer( partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) init_weight(self.weight) if self.bias is not None: fan_in = np.prod(self.weight_shape[1:]) bound = 1 / math.sqrt(fan_in) init_bias = get_condconv_initializer( partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) init_bias(self.bias) def forward(self, x, routing_weights): B, C, H, W = x.shape weight = torch.matmul(routing_weights, self.weight) new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size weight = weight.view(new_weight_shape) bias = None if self.bias is not None: bias = torch.matmul(routing_weights, self.bias) bias = bias.view(B * self.out_channels) # move batch elements with channels so each batch element can be efficiently convolved with separate kernel # reshape instead of view to work with channels_last input x = x.reshape(1, B * C, H, W) if self.dynamic_padding: out = conv2d_same( x, weight, bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups * B) else: out = F.conv2d( x, weight, bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups * B) out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) # Literal port (from TF definition) # x = torch.split(x, 1, 0) # weight = torch.split(weight, 1, 0) # if self.bias is not None: # bias = torch.matmul(routing_weights, self.bias) # bias = torch.split(bias, 1, 0) # else: # bias = [None] * B # out = [] # for xi, wi, bi in zip(x, weight, bias): # wi = wi.view(*self.weight_shape) # if bi is not None: # bi = bi.view(*self.bias_shape) # out.append(self.conv_fn( # xi, wi, bi, stride=self.stride, padding=self.padding, # dilation=self.dilation, groups=self.groups)) # out = torch.cat(out, 0) return out