|
|
|
@ -81,7 +81,7 @@ class BilinearAttnTransform(nn.Module):
|
|
|
|
|
self.groups = groups
|
|
|
|
|
self.in_channels = in_channels
|
|
|
|
|
|
|
|
|
|
def resize_mat(self, x, t):
|
|
|
|
|
def resize_mat(self, x, t: int):
|
|
|
|
|
B, C, block_size, block_size1 = x.shape
|
|
|
|
|
assert block_size == block_size1
|
|
|
|
|
if t <= 1:
|
|
|
|
@ -100,10 +100,8 @@ class BilinearAttnTransform(nn.Module):
|
|
|
|
|
out = self.conv1(x)
|
|
|
|
|
rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
|
|
|
|
|
cp = F.adaptive_max_pool2d(out, (1, self.block_size))
|
|
|
|
|
p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size)
|
|
|
|
|
q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size)
|
|
|
|
|
p = F.sigmoid(p)
|
|
|
|
|
q = F.sigmoid(q)
|
|
|
|
|
p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
|
|
|
|
|
q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
|
|
|
|
|
p = p / p.sum(dim=3, keepdim=True)
|
|
|
|
|
q = q / q.sum(dim=2, keepdim=True)
|
|
|
|
|
p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
|
|
|
|
|