|
|
@ -41,7 +41,7 @@ class StdConv2d(nn.Conv2d):
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
weight = F.batch_norm(
|
|
|
|
weight = F.batch_norm(
|
|
|
|
self.weight.view(1, self.out_channels, -1), None, None,
|
|
|
|
self.weight.reshape(1, self.out_channels, -1), None, None,
|
|
|
|
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
|
|
|
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
|
|
|
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
return x
|
|
|
|
return x
|
|
|
@ -67,7 +67,7 @@ class StdConv2dSame(nn.Conv2d):
|
|
|
|
if self.same_pad:
|
|
|
|
if self.same_pad:
|
|
|
|
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
|
|
|
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
|
|
|
weight = F.batch_norm(
|
|
|
|
weight = F.batch_norm(
|
|
|
|
self.weight.view(1, self.out_channels, -1), None, None,
|
|
|
|
self.weight.reshape(1, self.out_channels, -1), None, None,
|
|
|
|
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
|
|
|
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
|
|
|
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
return x
|
|
|
|
return x
|
|
|
@ -96,7 +96,7 @@ class ScaledStdConv2d(nn.Conv2d):
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
weight = F.batch_norm(
|
|
|
|
weight = F.batch_norm(
|
|
|
|
self.weight.view(1, self.out_channels, -1), None, None,
|
|
|
|
self.weight.reshape(1, self.out_channels, -1), None, None,
|
|
|
|
weight=(self.gain * self.scale).view(-1),
|
|
|
|
weight=(self.gain * self.scale).view(-1),
|
|
|
|
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
|
|
|
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
|
|
|
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
@ -127,7 +127,7 @@ class ScaledStdConv2dSame(nn.Conv2d):
|
|
|
|
if self.same_pad:
|
|
|
|
if self.same_pad:
|
|
|
|
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
|
|
|
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
|
|
|
|
weight = F.batch_norm(
|
|
|
|
weight = F.batch_norm(
|
|
|
|
self.weight.view(1, self.out_channels, -1), None, None,
|
|
|
|
self.weight.reshape(1, self.out_channels, -1), None, None,
|
|
|
|
weight=(self.gain * self.scale).view(-1),
|
|
|
|
weight=(self.gain * self.scale).view(-1),
|
|
|
|
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
|
|
|
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
|
|
|
|
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
|
|
|