Fix torchscript issue in bat

pull/693/head
Ross Wightman 3 years ago
parent d17b374f0f
commit b7a568f065

@ -81,7 +81,7 @@ class BilinearAttnTransform(nn.Module):
self.groups = groups
self.in_channels = in_channels
def resize_mat(self, x, t):
def resize_mat(self, x, t: int):
B, C, block_size, block_size1 = x.shape
assert block_size == block_size1
if t <= 1:
@ -100,10 +100,8 @@ class BilinearAttnTransform(nn.Module):
out = self.conv1(x)
rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
cp = F.adaptive_max_pool2d(out, (1, self.block_size))
p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size)
q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size)
p = F.sigmoid(p)
q = F.sigmoid(q)
p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
p = p / p.sum(dim=3, keepdim=True)
q = q / q.sum(dim=2, keepdim=True)
p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(

Loading…
Cancel
Save