From b7a568f06504310381733cba0cf8cc54a557442c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 8 Jun 2021 23:19:51 -0700 Subject: [PATCH] Fix torchscript issue in bat --- timm/models/layers/non_local_attn.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py index d20a5f3e..a537d60e 100644 --- a/timm/models/layers/non_local_attn.py +++ b/timm/models/layers/non_local_attn.py @@ -81,7 +81,7 @@ class BilinearAttnTransform(nn.Module): self.groups = groups self.in_channels = in_channels - def resize_mat(self, x, t): + def resize_mat(self, x, t: int): B, C, block_size, block_size1 = x.shape assert block_size == block_size1 if t <= 1: @@ -100,10 +100,8 @@ class BilinearAttnTransform(nn.Module): out = self.conv1(x) rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) cp = F.adaptive_max_pool2d(out, (1, self.block_size)) - p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size) - q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size) - p = F.sigmoid(p) - q = F.sigmoid(q) + p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid() + q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid() p = p / p.sum(dim=3, keepdim=True) q = q / q.sum(dim=2, keepdim=True) p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(