diff --git a/timm/models/layers/blurpool.py b/timm/models/layers/blurpool.py index 6df2b748..57af3e0e 100644 --- a/timm/models/layers/blurpool.py +++ b/timm/models/layers/blurpool.py @@ -49,7 +49,7 @@ class BlurPool2d(nn.Module): self.blur_filter = fn(self.blur_filter) def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # type: ignore + C = input_tensor.shape[1] return F.conv2d( self.padding(input_tensor), - self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), - stride=self.stride, groups=input_tensor.shape[1]) + self.blur_filter.type(input_tensor.dtype).expand(C, -1, -1, -1), stride=self.stride, groups=C)