diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py index 3ccc16e1..d896ba5c 100644 --- a/timm/models/layers/std_conv.py +++ b/timm/models/layers/std_conv.py @@ -41,7 +41,7 @@ class StdConv2d(nn.Conv2d): def forward(self, x): weight = F.batch_norm( - self.weight.view(1, self.out_channels, -1), None, None, + self.weight.reshape(1, self.out_channels, -1), None, None, training=True, momentum=0., eps=self.eps).reshape_as(self.weight) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -67,7 +67,7 @@ class StdConv2dSame(nn.Conv2d): if self.same_pad: x = pad_same(x, self.kernel_size, self.stride, self.dilation) weight = F.batch_norm( - self.weight.view(1, self.out_channels, -1), None, None, + self.weight.reshape(1, self.out_channels, -1), None, None, training=True, momentum=0., eps=self.eps).reshape_as(self.weight) x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) return x @@ -96,7 +96,7 @@ class ScaledStdConv2d(nn.Conv2d): def forward(self, x): weight = F.batch_norm( - self.weight.view(1, self.out_channels, -1), None, None, + self.weight.reshape(1, self.out_channels, -1), None, None, weight=(self.gain * self.scale).view(-1), training=True, momentum=0., eps=self.eps).reshape_as(self.weight) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) @@ -127,7 +127,7 @@ class ScaledStdConv2dSame(nn.Conv2d): if self.same_pad: x = pad_same(x, self.kernel_size, self.stride, self.dilation) weight = F.batch_norm( - self.weight.view(1, self.out_channels, -1), None, None, + self.weight.reshape(1, self.out_channels, -1), None, None, weight=(self.gain * self.scale).view(-1), training=True, momentum=0., eps=self.eps).reshape_as(self.weight) return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)