From 11ae795e99f6146dfada86eeef8dcd8d1dcb8679 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 25 May 2021 10:15:32 -0700
Subject: [PATCH] Redo LeViT attention bias caching in a way that works with
 both torchscript and DataParallel

---
 timm/models/levit.py | 49 +++++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 37 insertions(+), 12 deletions(-)

diff --git a/timm/models/levit.py b/timm/models/levit.py
index 96a0c85b..5019ee9a 100644
--- a/timm/models/levit.py
+++ b/timm/models/levit.py
@@ -26,6 +26,7 @@ Modifications by/coyright Copyright 2021 Ross Wightman
 import itertools
 from copy import deepcopy
 from functools import partial
+from typing import Dict
 
 import torch
 import torch.nn as nn
@@ -255,6 +256,8 @@ class Subsample(nn.Module):
 
 
 class Attention(nn.Module):
+    ab: Dict[str, torch.Tensor]
+
     def __init__(
             self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
         super().__init__()
@@ -286,20 +289,31 @@ class Attention(nn.Module):
                 idxs.append(attention_offsets[offset])
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
-        self.ab = None
+        self.ab = {}
 
     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
+        if mode and self.ab:
+            self.ab = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.ab:
+                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.ab[device_key]
 
     def forward(self, x):  # x (B,C,H,W)
         if self.use_conv:
             B, C, H, W = x.shape
             q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = (q.transpose(-2, -1) @ k) * self.scale + ab
+
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
+
             x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
         else:
             B, N, C = x.shape
@@ -308,15 +322,18 @@ class Attention(nn.Module):
             q = q.permute(0, 2, 1, 3)
             k = k.permute(0, 2, 1, 3)
             v = v.permute(0, 2, 1, 3)
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = q @ k.transpose(-2, -1) * self.scale + ab
+
+            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
+
             x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
         x = self.proj(x)
         return x
 
 
 class AttentionSubsample(nn.Module):
+    ab: Dict[str, torch.Tensor]
+
     def __init__(
             self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
             act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
@@ -366,12 +383,22 @@ class AttentionSubsample(nn.Module):
                 idxs.append(attention_offsets[offset])
         self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
         self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
-        self.ab = None
+        self.ab = {}  # per-device attention_biases cache
 
     @torch.no_grad()
     def train(self, mode=True):
         super().train(mode)
-        self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs]
+        if mode and self.ab:
+            self.ab = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.ab:
+                self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.ab[device_key]
 
     def forward(self, x):
         if self.use_conv:
@@ -379,8 +406,7 @@ class AttentionSubsample(nn.Module):
             k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
             q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
 
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = (q.transpose(-2, -1) @ k) * self.scale + ab
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
 
             x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
@@ -391,8 +417,7 @@
             v = v.permute(0, 2, 1, 3)  # BHNC
             q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
 
-            ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab
-            attn = q @ k.transpose(-2, -1) * self.scale + ab
+            attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
             attn = attn.softmax(dim=-1)
 
             x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
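
The pattern in miniature: a self-contained sketch, not timm's API (the module name BiasCachedAttention and the toy shapes below are invented for illustration). The class-level annotation ab: Dict[str, torch.Tensor] gives torchscript a concrete container type to compile, and keying the cached bias tensor by str(device) lets each DataParallel replica build and reuse its own copy at eval time.

# Minimal sketch (assumed names, not timm code) of per-device attention bias caching.
from typing import Dict

import torch
import torch.nn as nn


class BiasCachedAttention(nn.Module):
    # Class-level annotation: torchscript needs a concrete type for the cache dict.
    ab: Dict[str, torch.Tensor]

    def __init__(self, num_heads: int = 4, seq_len: int = 16):
        super().__init__()
        # Learnable biases plus a fixed index map, mirroring attention_biases /
        # attention_bias_idxs in the patch (the toy index map is just arange here).
        self.attention_biases = nn.Parameter(torch.zeros(num_heads, seq_len))
        self.register_buffer('attention_bias_idxs', torch.arange(seq_len).repeat(seq_len, 1))
        self.ab = {}  # per-device cache of the gathered bias tensor

    @torch.no_grad()
    def train(self, mode: bool = True):
        super().train(mode)
        if mode and self.ab:
            self.ab = {}  # biases keep changing while training, so drop stale entries
        return self

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        if self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        # Eval: gather once per device; DataParallel replicas each hit their own key.
        device_key = str(device)
        if device_key not in self.ab:
            self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
        return self.ab[device_key]

    def forward(self, attn: torch.Tensor) -> torch.Tensor:
        # attn: (B, num_heads, seq_len, seq_len) pre-softmax attention logits
        return (attn + self.get_attention_biases(attn.device)).softmax(dim=-1)


m = BiasCachedAttention().eval()
scripted = torch.jit.script(m)           # compiles: ab has a scriptable container type
out = scripted(torch.randn(2, 4, 16, 16))

The cache is dropped on train(True) because the biases are still being optimized there; at eval the gather runs at most once per device, and under nn.DataParallel each replica indexes the dict with its own device string, so no replica ever reads a tensor that lives on another GPU.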