diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 7ab8d659..462ad8fe 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -89,10 +89,11 @@ class LinearBottleneck(nn.Module): x = self.se(x) x = self.act_dw(x) x = self.conv_pwl(x) - if self.drop_path is not None: - x = self.drop_path(x) if self.use_shortcut: - x[:, 0:self.in_channels] += shortcut + if self.drop_path is not None: + x = self.drop_path(x) + + x[:, 0:self.in_channels] += shortcut return x