@@ -40,18 +40,16 @@ class BlurPool2d(nn.Module):
         blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
         blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
-        # FIXME figure a clean hack to prevent the filter from getting saved in weights, but still
-        # plays nice with recursive module apply for fn like .cuda(), .type(), etc -RW
-        self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1)))
+        self.blur_filter = blur_filter[None, None, :, :]
+
+    def _apply(self, fn):
+        # override nn.Module _apply to prevent need for blur_filter to be registered as a buffer,
+        # this keeps it out of state dict, but allows .cuda(), .type(), etc to work as expected
+        super(BlurPool2d, self)._apply(fn)
+        self.blur_filter = fn(self.blur_filter)
 
     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:  # type: ignore
         if not torch.is_tensor(input_tensor):
             raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(input_tensor)))
         if not len(input_tensor.shape) == 4:
             raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}".format(input_tensor.shape))
         # apply blur_filter on input
         return F.conv2d(
             self.padding(input_tensor),
-            self.blur_filter.type(input_tensor.dtype),
-            stride=self.stride,
-            groups=input_tensor.shape[1])
+            self.blur_filter.type(input_tensor.dtype).expand(input_tensor.shape[1], -1, -1, -1),
+            stride=self.stride, groups=input_tensor.shape[1])
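
Note on the kernel construction in this hunk: np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1) expands (0.5 + 0.5x)**(n-1), so its coefficients are a normalized row of Pascal's triangle, and the outer product of that row with itself gives the separable 2-D blur kernel used for anti-aliased downsampling (BlurPool, Zhang 2019). A standalone check for the default 3-tap case (values shown are illustrative, not taken from the patch):

    import numpy as np

    # (0.5 + 0.5x)**2 -> [0.25, 0.5, 0.25]: the binomial row for a 3-tap blur,
    # already normalized so the taps sum to 1.
    blur_matrix = (np.poly1d((0.5, 0.5)) ** 2).coeffs
    print(blur_matrix)                          # [0.25 0.5  0.25]

    # Outer product of the row with itself gives the separable 2-D kernel.
    kernel_2d = blur_matrix[:, None] * blur_matrix[None, :]
    print(kernel_2d.shape, kernel_2d.sum())     # (3, 3) 1.0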
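
The new forward path applies that single kernel to every channel independently: expand() broadcasts the 1 x 1 x k x k filter to a C x 1 x k x k view (C = number of input channels) without copying memory, and groups=C turns F.conv2d into a per-channel (depthwise) convolution. A minimal self-contained sketch of the same idea; the reflection padding and stride 2 are assumptions here, since the module's actual padding layer is defined outside this hunk, and the tensor sizes are made up for illustration:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8, 16, 16)                             # B x C x H x W
    row = torch.tensor([0.25, 0.5, 0.25])
    blur_filter = (row[:, None] * row[None, :])[None, None]   # 1 x 1 x 3 x 3

    C = x.shape[1]
    # groups=C makes each output channel depend only on its own input channel;
    # expand() yields a C x 1 x 3 x 3 view, so the kernel is never copied C times.
    y = F.conv2d(F.pad(x, (1, 1, 1, 1), mode='reflect'),
                 blur_filter.expand(C, -1, -1, -1),
                 stride=2, groups=C)
    print(y.shape)                                            # torch.Size([2, 8, 8, 8])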
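
The _apply override added above exists so blur_filter stays out of state_dict() while .cuda(), .type(), .half(), etc. still move and cast it with the module, which is what the removed FIXME asked for. For reference, PyTorch 1.6 and later expose the same behaviour directly through non-persistent buffers; a minimal sketch of that alternative (Blur2d is a made-up name, not the patched class):

    import torch
    import torch.nn as nn

    class Blur2d(nn.Module):
        """Illustrative only: a fixed 3x3 binomial blur kept out of the state dict."""
        def __init__(self):
            super().__init__()
            row = torch.tensor([0.25, 0.5, 0.25])
            # persistent=False keeps the buffer out of state_dict(), while
            # .cuda()/.half()/.to() still move and cast it with the module.
            self.register_buffer('blur_filter', (row[:, None] * row[None, :])[None, None],
                                 persistent=False)

    m = Blur2d()
    print(list(m.state_dict().keys()))    # [] -> nothing serialized
    print(m.half().blur_filter.dtype)     # torch.float16 -> casts follow the module

The hand-rolled override in the patch has the advantage of not depending on the persistent flag, so it also works on PyTorch releases older than 1.6.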