|
|
|
@ -49,7 +49,7 @@ class BlurPool2d(nn.Module):
|
|
|
|
|
self.blur_filter = fn(self.blur_filter)
|
|
|
|
|
|
|
|
|
|
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
|
|
|
|
|
C = input_tensor.shape[1]
|
|
|
|
|
return F.conv2d(
|
|
|
|
|
self.padding(input_tensor),
|
|
|
|
|
self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1),
|
|
|
|
|
stride=self.stride, groups=input_tensor.shape[1])
|
|
|
|
|
self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), stride=self.stride, groups=C)
|
|
|
|
|