|
|
@ -1,8 +1,9 @@
|
|
|
|
import math
|
|
|
|
""" Median Pool
|
|
|
|
import torch
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
|
|
|
|
"""
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from torch.nn.modules.utils import _pair, _quadruple
|
|
|
|
from .helpers import tup_pair, tup_quadruple
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MedianPool2d(nn.Module):
|
|
|
|
class MedianPool2d(nn.Module):
|
|
|
@ -16,9 +17,9 @@ class MedianPool2d(nn.Module):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
|
|
|
|
def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
|
|
|
|
super(MedianPool2d, self).__init__()
|
|
|
|
super(MedianPool2d, self).__init__()
|
|
|
|
self.k = _pair(kernel_size)
|
|
|
|
self.k = tup_pair(kernel_size)
|
|
|
|
self.stride = _pair(stride)
|
|
|
|
self.stride = tup_pair(stride)
|
|
|
|
self.padding = _quadruple(padding) # convert to l, r, t, b
|
|
|
|
self.padding = tup_quadruple(padding) # convert to l, r, t, b
|
|
|
|
self.same = same
|
|
|
|
self.same = same
|
|
|
|
|
|
|
|
|
|
|
|
def _padding(self, x):
|
|
|
|
def _padding(self, x):
|
|
|
|