From 7ef7788ee935f3197cc843a972171c168c36b982 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 23 Feb 2021 13:15:52 -0800 Subject: [PATCH] Fix CUDA crash w/ channels-last + CSP models. Remove use of chunk() --- timm/models/cspnet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index ca9eaf16..afd1dcd7 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -264,9 +264,11 @@ class CrossStage(nn.Module): if self.conv_down is not None: x = self.conv_down(x) x = self.conv_exp(x) - xs, xb = x.chunk(2, dim=1) + split = x.shape[1] // 2 + xs, xb = x[:, :split], x[:, split:] xb = self.blocks(xb) - out = self.conv_transition(torch.cat([xs, self.conv_transition_b(xb)], dim=1)) + xb = self.conv_transition_b(xb).contiguous() + out = self.conv_transition(torch.cat([xs, xb], dim=1)) return out