Blur filter no longer a buffer

pull/101/head
Ross Wightman 5 years ago committed by Chris Ha
parent 6cdeca24a3
commit f17b42bc33

@ -40,18 +40,16 @@ class BlurPool2d(nn.Module):
blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
# FIXME figure a clean hack to prevent the filter from getting saved in weights, but still
# plays nice with recursive module apply for fn like .cuda(), .type(), etc -RW
self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1)))
self.blur_filter = blur_filter[None, None, :, :]
def _apply(self, fn):
# override nn.Module _apply to prevent need for blur_filter to be registered as a buffer,
# this keeps it out of state dict, but allows .cuda(), .type(), etc to work as expected
super(BlurPool2d, self)._apply(fn)
self.blur_filter = fn(self.blur_filter)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
if not torch.is_tensor(input_tensor):
raise TypeError("Input input type is not a torch.Tensor. Got {}".format(type(input_tensor)))
if not len(input_tensor.shape) == 4:
raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}".format(input_tensor.shape))
# apply blur_filter on input
return F.conv2d(
self.padding(input_tensor),
self.blur_filter.type(input_tensor.dtype),
stride=self.stride,
groups=input_tensor.shape[1])
self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1),
stride=self.stride, groups=input_tensor.shape[1])

Loading…
Cancel
Save