You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
49 lines
1.7 KiB
49 lines
1.7 KiB
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.nn.modules.utils import _pair, _quadruple
|
|
|
|
|
|
class MedianPool2d(nn.Module):
    """Median pool module (usable as a median filter when ``stride=1``).

    The input is padded, split into ``kernel_size`` windows with
    ``Tensor.unfold``, and each window is reduced with ``Tensor.median``.

    Args:
        kernel_size: size of pooling kernel, int or 2-tuple ``(kh, kw)``.
        stride: pool stride, int or 2-tuple ``(sh, sw)``.
        padding: pool padding, int or 4-tuple ``(l, r, t, b)`` as in
            ``torch.nn.functional.pad``. Ignored when ``same`` is True.
        same: override ``padding`` and compute "same" padding from the input
            size, so the output spatial size is ``ceil(in / stride)``.
        padding_mode: ``mode`` passed to ``F.pad``; one of ``'reflect'``
            (default, the original behaviour), ``'replicate'``,
            ``'circular'`` or ``'constant'`` (zero padding, which biases
            border medians towards 0). Note that ``'reflect'`` requires each
            pad amount to be smaller than the corresponding input dimension,
            so tiny inputs with large kernels need ``'replicate'``.

    Note:
        For windows with an even number of elements, ``Tensor.median``
        returns the lower of the two middle values, not their mean.
    """

    _PADDING_MODES = ('reflect', 'replicate', 'circular', 'constant')

    def __init__(self, kernel_size=3, stride=1, padding=0, same=False,
                 padding_mode='reflect'):
        super().__init__()
        # Fail at construction time rather than on the first forward pass.
        if padding_mode not in self._PADDING_MODES:
            raise ValueError(
                f"padding_mode must be one of {self._PADDING_MODES}, "
                f"got {padding_mode!r}")
        self.k = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _quadruple(padding)  # convert to (l, r, t, b)
        self.same = same
        self.padding_mode = padding_mode

    @staticmethod
    def _same_pad(size, k, s):
        """Total padding along one axis so the output size is ceil(size / s)."""
        rem = size % s
        # Equivalent to max((ceil(size / s) - 1) * s + k - size, 0).
        return max(k - (s if rem == 0 else rem), 0)

    def _padding(self, x):
        """Return the ``(l, r, t, b)`` padding to apply to ``x``.

        With ``same=True`` the amount is derived from dims 2 and 3 of ``x``
        and split as evenly as possible, with the odd pixel on the right /
        bottom. Otherwise the fixed ``self.padding`` is returned.
        """
        if not self.same:
            return self.padding
        ih, iw = x.size()[2:]
        ph = self._same_pad(ih, self.k[0], self.stride[0])
        pw = self._same_pad(iw, self.k[1], self.stride[1])
        pl = pw // 2
        pt = ph // 2
        return (pl, pw - pl, pt, ph - pt)

    def forward(self, x):
        """Median-pool ``x``.

        Expects a 4-D tensor; dims 2 and 3 are pooled (treated as height
        and width) and dims 0 and 1 are left untouched.
        """
        x = F.pad(x, self._padding(x), mode=self.padding_mode)
        # (N, C, H, W) -> (N, C, oh, ow, kh, kw): one kh x kw window per output pixel.
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        # Flatten each window, then take its median; [0] keeps the values, drops indices.
        x = x.reshape(x.size()[:4] + (-1,)).median(dim=-1)[0]
        return x
|