One too many changes at a time, fix missing C

pull/101/head
Ross Wightman 4 years ago committed by Chris Ha
parent f17b42bc33
commit 1a9ab07307

@ -49,7 +49,7 @@ class BlurPool2d(nn.Module):
self.blur_filter = fn(self.blur_filter)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore
C = input_tensor.shape[1]
return F.conv2d(
self.padding(input_tensor),
self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1),
stride=self.stride, groups=input_tensor.shape[1])
self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), stride=self.stride, groups=C)

Loading…
Cancel
Save