Fix bottleneck attn transpose typo, hopefully these train better now..

pull/880/head
Ross Wightman 3 years ago
parent 80075b0b8a
commit b81e79aae9

@ -122,7 +122,7 @@ class BottleneckAttn(nn.Module):
attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W
attn_out = attn_logits.softmax(dim=-1) attn_out = attn_logits.softmax(dim=-1)
attn_out = (attn_out @ v).transpose(1, 2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W
attn_out = self.pool(attn_out) attn_out = self.pool(attn_out)
return attn_out return attn_out

Loading…
Cancel
Save