Redo LeViT attention bias caching in a way that works with both torchscript and DataParallel

pull/637/head
Ross Wightman 3 years ago
parent d400f1dbdd
commit 11ae795e99
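
The diff below swaps the single cached `self.ab` tensor (previously set to `None` in train mode) for a per-device dict keyed by `str(device)`, exposed through a `get_attention_biases()` helper: the explicit `Dict[str, torch.Tensor]` annotation keeps the attribute scriptable, and keying by device string lets each `nn.DataParallel` replica cache its own gathered bias tensor. The snippet that follows is not part of the commit; it is a minimal standalone sketch of the same caching pattern, with illustrative class and attribute names (`CachedBias`, `bias_cache`, `get_biases`) that do not appear in the diff:

from typing import Dict

import torch
import torch.nn as nn


class CachedBias(nn.Module):
    # TorchScript needs an explicit class-level annotation for a mutable dict attribute
    bias_cache: Dict[str, torch.Tensor]

    def __init__(self, num_heads: int = 4, seq_len: int = 16):
        super().__init__()
        self.biases = nn.Parameter(torch.zeros(num_heads, seq_len * seq_len))
        self.register_buffer('bias_idxs', torch.arange(seq_len * seq_len).view(seq_len, seq_len))
        self.bias_cache = {}  # one gathered tensor per device, keyed by str(device)

    @torch.no_grad()
    def train(self, mode: bool = True):
        super().train(mode)
        if mode and self.bias_cache:
            self.bias_cache = {}  # biases may change during training, so drop stale entries
        return self

    def get_biases(self, device: torch.device) -> torch.Tensor:
        if self.training:
            return self.biases[:, self.bias_idxs]  # always recompute while training
        key = str(device)
        if key not in self.bias_cache:
            self.bias_cache[key] = self.biases[:, self.bias_idxs]
        return self.bias_cache[key]

    def forward(self, attn: torch.Tensor) -> torch.Tensor:
        # attn: (B, num_heads, N, N); gathered biases broadcast as (num_heads, N, N)
        return attn + self.get_biases(attn.device)
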

@@ -26,6 +26,7 @@ Modifications by/coyright Copyright 2021 Ross Wightman
 import itertools
 from copy import deepcopy
 from functools import partial
+from typing import Dict
 
 import torch
 import torch.nn as nn
@@ -255,6 +256,8 @@ class Subsample(nn.Module):
 
 
 class Attention(nn.Module):
+    ab: Dict[str, torch.Tensor]
+
     def __init__(
             self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
         super().__init__()
@@ -286,20 +289,31 @@ class Attention(nn.Module):
             idxs.append(attention_offsets[offset])
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
-        self.ab = None
+        self.ab = {}
 
     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
+        if mode and self.ab:
+            self.ab = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.ab:
+                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.ab[device_key]
 
     def forward(self, x):  # x (B,C,H,W)
         if self.use_conv:
             B, C, H, W = x.shape
             q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = (q.transpose(-2, -1) @ k) * self.scale + ab
+
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
             x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
         else:
             B, N, C = x.shape
@@ -308,15 +322,18 @@ class Attention(nn.Module):
             q = q.permute(0, 2, 1, 3)
             k = k.permute(0, 2, 1, 3)
             v = v.permute(0, 2, 1, 3)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = q @ k.transpose(-2, -1) * self.scale + ab
+
+            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
             x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
         x = self.proj(x)
         return x
 
 
 class AttentionSubsample(nn.Module):
+    ab: Dict[str, torch.Tensor]
+
     def __init__(
             self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
             act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
@@ -366,12 +383,22 @@ class AttentionSubsample(nn.Module):
             idxs.append(attention_offsets[offset])
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
-        self.ab = None
+        self.ab = {}  # per-device attention_biases cache
 
     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
+        if mode and self.ab:
+            self.ab = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.ab:
+                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.ab[device_key]
 
     def forward(self, x):
         if self.use_conv:
@@ -379,8 +406,7 @@ class AttentionSubsample(nn.Module):
             k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
             q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = (q.transpose(-2, -1) @ k) * self.scale + ab
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
             x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
@@ -391,8 +417,7 @@ class AttentionSubsample(nn.Module):
             v = v.permute(0, 2, 1, 3)  # BHNC
             q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = q @ k.transpose(-2, -1) * self.scale + ab
+            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)
            x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
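
A quick, hypothetical smoke test for the two paths this commit targets, not part of the commit itself. It assumes a timm install containing this LeViT revision and that the `levit_128` model name resolves through `timm.create_model`; only public PyTorch APIs (`torch.jit.script`, `nn.DataParallel`) are used:

import torch
import torch.nn as nn
import timm  # assumed: a timm version that includes this LeViT change

model = timm.create_model('levit_128', pretrained=False).eval()

# torchscript path: the Dict[str, torch.Tensor] annotation lets the per-device
# cache attribute compile instead of tripping over an untyped None attribute
scripted = torch.jit.script(model)
with torch.no_grad():
    _ = scripted(torch.randn(2, 3, 224, 224))

# DataParallel path: each replica fills the cache under its own str(device) key
# rather than reusing a single tensor pinned to one GPU
if torch.cuda.device_count() > 1:
    dp = nn.DataParallel(model.cuda())
    with torch.no_grad():
        _ = dp(torch.randn(8, 3, 224, 224).cuda())
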
