diff --git a/timm/models/layers/blurpool.py b/timm/models/layers/blurpool.py
index a12274d8..6df2b748 100644
--- a/timm/models/layers/blurpool.py
+++ b/timm/models/layers/blurpool.py
@@ -40,18 +40,18 @@ class BlurPool2d(nn.Module):
         blur_matrix = (np.poly1d((0.5, 0.5)) ** (blur_filter_size - 1)).coeffs
         blur_filter = torch.Tensor(blur_matrix[:, None] * blur_matrix[None, :])
-        # FIXME figure a clean hack to prevent the filter from getting saved in weights, but still
-        # plays nice with recursive module apply for fn like .cuda(), .type(), etc -RW
-        self.register_buffer('blur_filter', blur_filter[None, None, :, :].repeat((self.channels, 1, 1, 1)))
+        self.blur_filter = blur_filter[None, None, :, :]  # plain attribute: kept out of state_dict
+
+    def _apply(self, fn):
+        # override nn.Module._apply so blur_filter need not be a registered buffer:
+        # it stays out of the state dict, yet .cuda(), .type(), etc still move/convert it
+        super(BlurPool2d, self)._apply(fn)
+        self.blur_filter = fn(self.blur_filter)
+        return self  # _apply must return self so Module.cuda()/.to()/.type() return the module

     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:  # type: ignore
-        if not torch.is_tensor(input_tensor):
-            raise TypeError("Input input type is not a torch.Tensor. Got {}".format(type(input_tensor)))
-        if not len(input_tensor.shape) == 4:
-            raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}".format(input_tensor.shape))
-        # apply blur_filter on input
+        C = input_tensor.shape[1]  # one depthwise filter group per input channel
         return F.conv2d(
             self.padding(input_tensor),
-            self.blur_filter.type(input_tensor.dtype),
-            stride=self.stride,
-            groups=input_tensor.shape[1])
+            self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1),
+            stride=self.stride, groups=C)