import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def _same_pad_total(in_size, kernel, stride, dilation):
    """Return the total padding TensorFlow 'SAME' needs along one spatial axis.

    'SAME' fixes the output length at ``ceil(in_size / stride)``; the padding is
    whatever makes the (dilated) kernel cover exactly that many positions,
    clamped at zero.
    """
    out_size = math.ceil(in_size / stride)
    effective_kernel = (kernel - 1) * dilation + 1
    return max((out_size - 1) * stride + effective_kernel - in_size, 0)


class Conv2dSame(nn.Conv2d):
    """TensorFlow-like 'SAME' convolution wrapper for 2D convolutions.

    Padding is computed from the actual input height/width on every forward
    pass, so the output spatial size is ``ceil(input_size / stride)``. When the
    total padding along an axis is odd, the extra row/column is added at the
    end (bottom/right), as TensorFlow does.

    The ``padding`` constructor argument is accepted only for signature
    compatibility with ``nn.Conv2d`` and is ignored: the underlying convolution
    is always constructed with ``padding=0``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        # ``padding`` is intentionally replaced by 0; forward() pads dynamically.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)

    def forward(self, x):
        """Pad ``x`` TF-'SAME'-style on its last two dims, then convolve."""
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        pad_h = _same_pad_total(ih, kh, self.stride[0], self.dilation[0])
        pad_w = _same_pad_total(iw, kw, self.stride[1], self.dilation[1])
        if pad_h > 0 or pad_w > 0:
            # F.pad lists the last dim first: (left, right, top, bottom).
            # The larger half goes on the right/bottom.
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2,
                          pad_h // 2, pad_h - pad_h // 2])
        return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


# helper method
def sconv2d(in_chs, out_chs, kernel_size, **kwargs):
    """Build a 2D convolution, choosing the implementation from ``padding``.

    Args:
        in_chs: Number of input channels.
        out_chs: Number of output channels.
        kernel_size: Convolution kernel size (forwarded to the conv module).
        **kwargs: Extra ``nn.Conv2d`` keyword arguments. ``padding`` may be:
            - ``'same'`` (case-insensitive): returns a ``Conv2dSame``;
            - ``'valid'`` (case-insensitive): returns ``nn.Conv2d`` with no padding;
            - anything else non-string: passed straight to ``nn.Conv2d``.
            Defaults to ``0``.

    Returns:
        An ``nn.Conv2d`` (or ``Conv2dSame``) module.

    Raises:
        ValueError: if ``padding`` is a string other than 'same' or 'valid'.
    """
    padding = kwargs.pop('padding', 0)
    if isinstance(padding, str):
        mode = padding.lower()
        if mode == 'same':
            return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
        if mode == 'valid':
            return nn.Conv2d(in_chs, out_chs, kernel_size, padding=0, **kwargs)
        # Previously any unknown string silently meant 'valid'; reject it instead.
        raise ValueError(
            f"Unsupported padding string {padding!r}; expected 'same' or 'valid'")
    return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)