diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index ca9eaf16..afd1dcd7 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -264,9 +264,11 @@ class CrossStage(nn.Module): if self.conv_down is not None: x = self.conv_down(x) x = self.conv_exp(x) - xs, xb = x.chunk(2, dim=1) + split = x.shape[1] // 2 + xs, xb = x[:, :split], x[:, split:] xb = self.blocks(xb) - out = self.conv_transition(torch.cat([xs, self.conv_transition_b(xb)], dim=1)) + xb = self.conv_transition_b(xb).contiguous() + out = self.conv_transition(torch.cat([xs, xb], dim=1)) return out