|
|
|
@ -26,6 +26,7 @@ Modifications by/coyright Copyright 2021 Ross Wightman
|
|
|
|
|
import itertools
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Dict
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
@ -255,6 +256,8 @@ class Subsample(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Attention(nn.Module):
|
|
|
|
|
ab: Dict[str, torch.Tensor]
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
|
|
|
|
|
super().__init__()
|
|
|
|
@ -286,20 +289,31 @@ class Attention(nn.Module):
|
|
|
|
|
idxs.append(attention_offsets[offset])
|
|
|
|
|
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
|
|
|
|
|
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
|
|
|
|
|
self.ab = None
|
|
|
|
|
self.ab = {}
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def train(self, mode=True):
|
|
|
|
|
super().train(mode)
|
|
|
|
|
self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
|
|
|
|
|
if mode and self.ab:
|
|
|
|
|
self.ab = {} # clear ab cache
|
|
|
|
|
|
|
|
|
|
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
|
|
|
|
|
if self.training:
|
|
|
|
|
return self.attention_biases[:, self.attention_bias_idxs]
|
|
|
|
|
else:
|
|
|
|
|
device_key = str(device)
|
|
|
|
|
if device_key not in self.ab:
|
|
|
|
|
self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
|
|
|
|
|
return self.ab[device_key]
|
|
|
|
|
|
|
|
|
|
def forward(self, x): # x (B,C,H,W)
|
|
|
|
|
if self.use_conv:
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
|
|
|
|
|
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
|
|
|
|
|
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
|
|
|
|
|
|
|
|
|
|
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
|
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
|
|
|
|
|
|
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
|
|
|
|
|
else:
|
|
|
|
|
B, N, C = x.shape
|
|
|
|
@ -308,15 +322,18 @@ class Attention(nn.Module):
|
|
|
|
|
q = q.permute(0, 2, 1, 3)
|
|
|
|
|
k = k.permute(0, 2, 1, 3)
|
|
|
|
|
v = v.permute(0, 2, 1, 3)
|
|
|
|
|
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
|
|
|
|
|
attn = q @ k.transpose(-2, -1) * self.scale + ab
|
|
|
|
|
|
|
|
|
|
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
|
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
|
|
|
|
|
|
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
|
|
|
|
x = self.proj(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AttentionSubsample(nn.Module):
|
|
|
|
|
ab: Dict[str, torch.Tensor]
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
|
|
|
|
|
act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
|
|
|
|
@ -366,12 +383,22 @@ class AttentionSubsample(nn.Module):
|
|
|
|
|
idxs.append(attention_offsets[offset])
|
|
|
|
|
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
|
|
|
|
|
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
|
|
|
|
|
self.ab = None
|
|
|
|
|
self.ab = {} # per-device attention_biases cache
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def train(self, mode=True):
|
|
|
|
|
super().train(mode)
|
|
|
|
|
self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
|
|
|
|
|
if mode and self.ab:
|
|
|
|
|
self.ab = {} # clear ab cache
|
|
|
|
|
|
|
|
|
|
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
|
|
|
|
|
if self.training:
|
|
|
|
|
return self.attention_biases[:, self.attention_bias_idxs]
|
|
|
|
|
else:
|
|
|
|
|
device_key = str(device)
|
|
|
|
|
if device_key not in self.ab:
|
|
|
|
|
self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
|
|
|
|
|
return self.ab[device_key]
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
if self.use_conv:
|
|
|
|
@ -379,8 +406,7 @@ class AttentionSubsample(nn.Module):
|
|
|
|
|
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
|
|
|
|
|
q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
|
|
|
|
|
|
|
|
|
|
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
|
|
|
|
|
attn = (q.transpose(-2, -1) @ k) * self.scale + ab
|
|
|
|
|
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
|
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
|
|
|
|
|
|
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
|
|
|
|
@ -391,8 +417,7 @@ class AttentionSubsample(nn.Module):
|
|
|
|
|
v = v.permute(0, 2, 1, 3) # BHNC
|
|
|
|
|
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
|
|
|
|
|
|
|
|
|
|
ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
|
|
|
|
|
attn = q @ k.transpose(-2, -1) * self.scale + ab
|
|
|
|
|
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
|
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
|
|
|
|
|
|
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
|
|
|
|
|