diff --git a/tests/test_models.py b/tests/test_models.py
index 7a3f143e..39e2dcdc 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -407,7 +407,6 @@ def test_model_backward_fx(model_name, batch_size):
 
 # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
 EXCLUDE_FX_JIT_FILTERS = [
-    'beit_*',
     'deit_*_distilled_patch16_224',
     'levit*',
     'pit_*_distilled_224',
diff --git a/timm/models/beit.py b/timm/models/beit.py
index 199c2a4b..f644b657 100644
--- a/timm/models/beit.py
+++ b/timm/models/beit.py
@@ -86,9 +86,11 @@ class Attention(nn.Module):
         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
         if qkv_bias:
             self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
             self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
         else:
             self.q_bias = None
+            self.k_bias = None
             self.v_bias = None
 
         if window_size:
@@ -127,13 +129,7 @@ class Attention(nn.Module):
 
     def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
         B, N, C = x.shape
-        qkv_bias = None
-        if self.q_bias is not None:
-            if torch.jit.is_scripting():
-                # FIXME requires_grad breaks w/ torchscript
-                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias), self.v_bias))
-            else:
-                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
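
The change above replaces branching on torch.jit.is_scripting() with an all-zero k_bias registered as a non-persistent buffer, so the fused qkv bias is a single torch.cat that works the same under eager, torchscript, and torch.fx tracing, which is what lets the 'beit_*' entry drop out of EXCLUDE_FX_JIT_FILTERS. Below is a minimal standalone sketch of that pattern; the module and parameter names (FusedQKV, dim, num_heads) are illustrative and not taken from timm.

# Illustrative sketch of the non-persistent zero-bias pattern used in the diff;
# names are hypothetical, not timm API.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedQKV(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = True):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            # persistent=False keeps the always-zero k bias out of the state_dict,
            # so checkpoints saved before this change still load without errors.
            self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        # No torch.jit.is_scripting() branch: the same concat is valid in eager
        # mode, under scripting, and after fx tracing.
        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
        qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # avoid indexing a tensor as a tuple for torchscript
        attn = (q @ k.transpose(-2, -1)) * (q.shape[-1] ** -0.5)
        x = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)
        return x

The design trade-off is a small amount of extra memory for a zero buffer in exchange for removing data-dependent control flow from forward(), and marking the buffer non-persistent preserves backward compatibility of existing checkpoints.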