From 515121cca1545a3a8ac3c077579f970bcfce00da Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 23 Sep 2021 15:43:48 -0700 Subject: [PATCH] Use reshape instead of view in std_conv, causing issues in recent PyTorch in channels_last --- timm/models/layers/std_conv.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py index 3ccc16e1..d896ba5c 100644 --- a/timm/models/layers/std_conv.py +++ b/timm/models/layers/std_conv.py @@ -41,7 +41,7 @@ class StdConv2d(nn.Conv2d): def forward(self, x): weight = F.batch_norm( - self.weight.view(1, self.out_channels, -1), None, None, + self.weight.reshape(1, self.out_channels, -1), None, None, training=True, momentum=0., eps=self.eps).reshape_as(self.weight) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -67,7 +67,7 @@ class StdConv2dSame(nn.Conv2d): if self.same_pad: x = pad_same(x, self.kernel_size, self.stride, self.dilation) weight = F.batch_norm( - self.weight.view(1, self.out_channels, -1), None, None, + self.weight.reshape(1, self.out_channels, -1), None, None, training=True, momentum=0., eps=self.eps).reshape_as(self.weight) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -96,7 +96,7 @@ class ScaledStdConv2d(nn.Conv2d): def forward(self, x): weight = F.batch_norm( - self.weight.view(1, self.out_channels, -1), None, None, + self.weight.reshape(1, self.out_channels, -1), None, None, weight=(self.gain * self.scale).view(-1), training=True, momentum=0., eps=self.eps).reshape_as(self.weight) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) @@ -127,7 +127,7 @@ class ScaledStdConv2dSame(nn.Conv2d): if self.same_pad: x = pad_same(x, self.kernel_size, self.stride, self.dilation) weight = F.batch_norm( - self.weight.view(1, self.out_channels, -1), None, None, + self.weight.reshape(1, self.out_channels, -1), None, None, weight=(self.gain * self.scale).view(-1), training=True, momentum=0., eps=self.eps).reshape_as(self.weight) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)