|
|
|
@ -130,7 +130,7 @@ class SpatialGate(nn.Module):
|
|
|
|
|
super(SpatialGate, self).__init__()
|
|
|
|
|
kernel_size = 7
|
|
|
|
|
self.compress = ChannelPool()
|
|
|
|
|
self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, bias=False)
|
|
|
|
|
self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2)
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x_compress = self.compress(x)
|
|
|
|
|
x_out = self.spatial(x_compress)
|
|
|
|
|