Redo LeViT attention bias caching in a way that works with both torchscript and DataParallel

pull/637/head
Ross Wightman 4 years ago
parent d400f1dbdd
commit 11ae795e99

@ -26,6 +26,7 @@ Modifications by/coyright Copyright 2021 Ross Wightman
import itertools
from copy import deepcopy
from functools import partial
from typing import Dict
import torch
import torch.nn as nn
@ -255,6 +256,8 @@ class Subsample(nn.Module):
class Attention(nn.Module):
ab: Dict[str, torch.Tensor]
def __init__(
self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
super().__init__()
@ -286,20 +289,31 @@ class Attention(nn.Module):
idxs.append(attention_offsets[offset])
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
self.ab = None
self.ab = {}
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
if mode and self.ab:
self.ab = {} # clear ab cache
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
if self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
device_key = str(device)
if device_key not in self.ab:
self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.ab[device_key]
def forward(self, x): # x (B,C,H,W)
if self.use_conv:
B, C, H, W = x.shape
q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
else:
B, N, C = x.shape
@ -308,15 +322,18 @@ class Attention(nn.Module):
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
attn = q @ k.transpose(-2, -1) * self.scale + ab
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
class AttentionSubsample(nn.Module):
ab: Dict[str, torch.Tensor]
def __init__(
self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
@ -366,12 +383,22 @@ class AttentionSubsample(nn.Module):
idxs.append(attention_offsets[offset])
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
self.ab = None
self.ab = {} # per-device attention_biases cache
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
if mode and self.ab:
self.ab = {} # clear ab cache
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
if self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
device_key = str(device)
if device_key not in self.ab:
self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.ab[device_key]
def forward(self, x):
if self.use_conv:
@ -379,8 +406,7 @@ class AttentionSubsample(nn.Module):
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
@ -391,8 +417,7 @@ class AttentionSubsample(nn.Module):
v = v.permute(0, 2, 1, 3) # BHNC
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
attn = q @ k.transpose(-2, -1) * self.scale + ab
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)

Loading…
Cancel
Save