Cleanup qkv_bias cat in beit model so it can be traced

pull/989/head
Ross Wightman 3 years ago
parent 1076a65df1
commit f2006b2437

@@ -407,7 +407,6 @@ def test_model_backward_fx(model_name, batch_size):
 # reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
 EXCLUDE_FX_JIT_FILTERS = [
-    'beit_*',
     'deit_*_distilled_patch16_224',
     'levit*',
     'pit_*_distilled_224',
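The filter above guards the test that FX-traces a model and then runs torch.jit.script on the resulting GraphModule; with the torch.jit.is_scripting() branch gone from beit's Attention, the 'beit_*' entry is no longer needed. As a rough, hedged sketch of that trace-then-script combination (Tiny is a throwaway module made up here, not part of timm or its test suite):

import torch
import torch.fx
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.fc(x))

traced = torch.fx.symbolic_trace(Tiny())  # FX trace first
scripted = torch.jit.script(traced)       # then TorchScript the traced GraphModule
print(scripted(torch.randn(2, 8)).shape)  # torch.Size([2, 8])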

@@ -86,9 +86,11 @@ class Attention(nn.Module):
         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
         if qkv_bias:
             self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+            self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
             self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
         else:
             self.q_bias = None
+            self.k_bias = None
             self.v_bias = None
         if window_size:
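The k bias is registered as a non-persistent buffer rather than a Parameter: it stays all zeros (BEiT learns no bias for k), and persistent=False keeps it out of the state_dict, so existing checkpoints that never contained a k bias should still load without key mismatches. A minimal sketch of that behaviour (BiasOwner is a made-up module name for illustration):

import torch
import torch.nn as nn

class BiasOwner(nn.Module):
    def __init__(self, all_head_dim: int = 8):
        super().__init__()
        self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
        # zero k bias kept as a non-persistent buffer: not trained, not saved
        self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
        self.v_bias = nn.Parameter(torch.zeros(all_head_dim))

print(sorted(BiasOwner().state_dict().keys()))  # ['q_bias', 'v_bias'] -- no 'k_bias'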
@@ -127,13 +129,7 @@ class Attention(nn.Module):
     def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
         B, N, C = x.shape
-        qkv_bias = None
-        if self.q_bias is not None:
-            if torch.jit.is_scripting():
-                # FIXME requires_grad breaks w/ torchscript
-                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias), self.v_bias))
-            else:
-                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
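With the torch.jit.is_scripting() branch removed, the bias concatenation is a single tensor op, and the qkv projection can be traced by torch.fx without special handling. A standalone sketch of the pattern (DemoQKVBias and its sizes are invented for illustration; this is not the actual timm Attention class):

import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F


class DemoQKVBias(nn.Module):
    """Minimal stand-in for the BEiT Attention bias handling after this change."""

    def __init__(self, dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        all_head_dim = dim
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        # q and v get learnable biases, k gets a fixed zero bias kept as a
        # non-persistent buffer (never trained, not saved in the state_dict)
        self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
        self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
        self.v_bias = nn.Parameter(torch.zeros(all_head_dim))

    def forward(self, x):
        B, N, C = x.shape
        # single cat, no is_scripting() branch -> FX-traceable
        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        return q, k, v


# the module goes through standard symbolic tracing with no special casing
traced = torch.fx.symbolic_trace(DemoQKVBias())
q, k, v = traced(torch.randn(2, 16, 64))
print(q.shape)  # torch.Size([2, 4, 16, 16])

The traced GraphModule can also be scripted afterwards, which is the FX-then-TorchScript combination that the test change above re-enables for the beit_* models.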
