Use reshape instead of view in std_conv, causing issues in recent PyTorch in channels_last

pull/880/head
Ross Wightman 3 years ago
parent da06cc61d4
commit 515121cca1

@ -41,7 +41,7 @@ class StdConv2d(nn.Conv2d):
def forward(self, x): def forward(self, x):
weight = F.batch_norm( weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None, self.weight.reshape(1, self.out_channels, -1), None, None,
training=True, momentum=0., eps=self.eps).reshape_as(self.weight) training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x return x
@ -67,7 +67,7 @@ class StdConv2dSame(nn.Conv2d):
if self.same_pad: if self.same_pad:
x = pad_same(x, self.kernel_size, self.stride, self.dilation) x = pad_same(x, self.kernel_size, self.stride, self.dilation)
weight = F.batch_norm( weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None, self.weight.reshape(1, self.out_channels, -1), None, None,
training=True, momentum=0., eps=self.eps).reshape_as(self.weight) training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x return x
@ -96,7 +96,7 @@ class ScaledStdConv2d(nn.Conv2d):
def forward(self, x): def forward(self, x):
weight = F.batch_norm( weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None, self.weight.reshape(1, self.out_channels, -1), None, None,
weight=(self.gain * self.scale).view(-1), weight=(self.gain * self.scale).view(-1),
training=True, momentum=0., eps=self.eps).reshape_as(self.weight) training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
@ -127,7 +127,7 @@ class ScaledStdConv2dSame(nn.Conv2d):
if self.same_pad: if self.same_pad:
x = pad_same(x, self.kernel_size, self.stride, self.dilation) x = pad_same(x, self.kernel_size, self.stride, self.dilation)
weight = F.batch_norm( weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None, self.weight.reshape(1, self.out_channels, -1), None, None,
weight=(self.gain * self.scale).view(-1), weight=(self.gain * self.scale).view(-1),
training=True, momentum=0., eps=self.eps).reshape_as(self.weight) training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

Loading…
Cancel
Save