Add ideas from 'Scaling ViT to 22-B Params', testing PyTorch 2.0 fused F.scaled_dot_product_attention impl in vit, vit_relpos, maxxvit / coatnet.

pull/1680/head
Ross Wightman 1 year ago committed by Ross Wightman
parent a3d528524a
commit 621e1b2182

@ -28,7 +28,7 @@ from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same

@ -17,6 +17,12 @@ try:
except ImportError:
has_apex = False
try:
from apex.normalization.fused_layer_norm import fused_rms_norm_affine, fused_rms_norm
has_apex_rmsnorm = True
except ImportError:
has_apex_rmsnorm = False
# fast (ie lower precision LN) can be disabled with this flag if issues crop up
_USE_FAST_NORM = False # defaulting to False for now
@ -76,3 +82,32 @@ def fast_layer_norm(
with torch.cuda.amp.autocast(enabled=False):
return F.layer_norm(x, normalized_shape, weight, bias, eps)
def rms_norm(
x: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
eps: float = 1e-5,
):
dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
v = torch.var(x, dim=dims, keepdim=True)
x = x * torch.rsqrt(v + eps)
if weight is not None:
x = x * weight
return x
def fast_rms_norm(
x: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
eps: float = 1e-5,
) -> torch.Tensor:
if torch.jit.is_scripting() or not has_apex_rmsnorm:
return rms_norm(x, normalized_shape, weight, eps)
if weight is None:
return fused_rms_norm(x, normalized_shape, eps)
else:
return fused_rms_norm_affine(x, weight, normalized_shape, eps)

@ -4,12 +4,14 @@ Norm layer definitions that support fast norm and consistent channel arg order (
Hacked together by / Copyright 2022 Ross Wightman
"""
import numbers
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm
class GroupNorm(nn.GroupNorm):
@ -115,3 +117,38 @@ class LayerNormExp2d(nn.LayerNorm):
else:
x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
return x
class RmsNorm(nn.Module):
""" RmsNorm w/ fast (apex) norm if available
"""
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
normalized_shape = channels
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter('weight', None)
self.reset_parameters()
def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE fast norm fallback needs our rms norm impl, so both paths through here.
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
return x

@ -83,8 +83,8 @@ def gen_relative_log_coords(
pretrained_win_size: Tuple[int, int] = (0, 0),
mode='swin',
):
assert mode in ('swin', 'cr', 'rw')
# as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
assert mode in ('swin', 'cr')
# as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
@ -100,18 +100,9 @@ def gen_relative_log_coords(
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs()) / math.log2(8)
else:
if mode == 'rw':
# cr w/ window size normalization -> [-1,1] log coords
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # scale to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs())
relative_coords_table /= math.log2(9) # -> [-1, 1]
else:
# mode == 'cr'
relative_coords_table = torch.sign(relative_coords_table) * torch.log(
1.0 + relative_coords_table.abs())
# mode == 'cr'
relative_coords_table = torch.sign(relative_coords_table) * torch.log(
1.0 + relative_coords_table.abs())
return relative_coords_table
@ -141,10 +132,6 @@ class RelPosMlp(nn.Module):
self.bias_act = nn.Sigmoid()
self.bias_gain = 16
mlp_bias = (True, False)
elif mode == 'rw':
self.bias_act = nn.Tanh()
self.bias_gain = 4
mlp_bias = True
else:
self.bias_act = nn.Identity()
self.bias_gain = None

@ -160,6 +160,7 @@ class Attention2d(nn.Module):
self.dim_head = dim_head
self.head_first = head_first
self.scale = dim_head ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
@ -175,15 +176,31 @@ class Attention2d(nn.Module):
else:
q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)
attn = (q.transpose(-2, -1) @ k) * self.scale
if self.rel_pos is not None:
attn = self.rel_pos(attn)
elif shared_rel_pos is not None:
attn = attn + shared_rel_pos
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
if self.fast_attn:
if self.rel_pos is not None:
attn_bias = self.rel_pos.get_bias()
elif shared_rel_pos is not None:
attn_bias = shared_rel_pos
else:
attn_bias = None
x = torch.nn.functional.scaled_dot_product_attention(
q.transpose(-1, -2),
k.transpose(-1, -2),
v.transpose(-1, -2),
attn_mask=attn_bias,
dropout_p=self.attn_drop.p,
).transpose(-1, -2).reshape(B, -1, H, W)
else:
q = q * self.scale
attn = q.transpose(-2, -1) @ k
if self.rel_pos is not None:
attn = self.rel_pos(attn)
elif shared_rel_pos is not None:
attn = attn + shared_rel_pos
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
x = self.proj(x)
x = self.proj_drop(x)
return x
@ -211,6 +228,7 @@ class AttentionCl(nn.Module):
self.dim_head = dim_head
self.head_first = head_first
self.scale = dim_head ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias)
self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
@ -227,15 +245,30 @@ class AttentionCl(nn.Module):
else:
q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2)
attn = (q @ k.transpose(-2, -1)) * self.scale
if self.rel_pos is not None:
attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
elif shared_rel_pos is not None:
attn = attn + shared_rel_pos
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(restore_shape + (-1,))
if self.fast_attn:
if self.rel_pos is not None:
attn_bias = self.rel_pos.get_bias()
elif shared_rel_pos is not None:
attn_bias = shared_rel_pos
else:
attn_bias = None
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_bias,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if self.rel_pos is not None:
attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
elif shared_rel_pos is not None:
attn = attn + shared_rel_pos
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(restore_shape + (-1,))
x = self.proj(x)
x = self.proj_drop(x)
return x

@ -37,7 +37,7 @@ import torch.utils.checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
resample_abs_pos_embed
resample_abs_pos_embed, RmsNorm
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._pretrained import generate_default_cfgs
@ -51,28 +51,49 @@ _logger = logging.getLogger(__name__)
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_norm=False,
attn_drop=0.,
proj_drop=0.,
norm_layer=nn.LayerNorm,
):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fast_attn:
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@ -96,6 +117,7 @@ class Block(nn.Module):
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
drop=0.,
attn_drop=0.,
init_values=None,
@ -105,13 +127,25 @@ class Block(nn.Module):
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=drop,
norm_layer=norm_layer,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -129,6 +163,7 @@ class ResPostBlock(nn.Module):
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
drop=0.,
attn_drop=0.,
init_values=None,
@ -139,11 +174,24 @@ class ResPostBlock(nn.Module):
super().__init__()
self.init_values = init_values
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=drop,
norm_layer=norm_layer,
)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -161,8 +209,61 @@ class ResPostBlock(nn.Module):
return x
class ParallelBlock(nn.Module):
class ParallelScalingBlock(nn.Module):
""" Parallel ViT block (MLP & Attention in parallel)
Based on:
'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
"""
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=drop,
norm_layer=norm_layer,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
y1 = self.drop_path1(self.ls1(self.attn(self.norm1(x))))
y2 = self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
x = x + y1 + y2
return x
class ParallelThingsBlock(nn.Module):
""" Parallel ViT block (N parallel attention followed by N parallel MLP)
Based on:
`Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
"""
def __init__(
self,
dim,
@ -170,6 +271,7 @@ class ParallelBlock(nn.Module):
num_parallel=2,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
init_values=None,
drop=0.,
attn_drop=0.,
@ -184,13 +286,26 @@ class ParallelBlock(nn.Module):
for _ in range(num_parallel):
self.attns.append(nn.Sequential(OrderedDict([
('norm', norm_layer(dim)),
('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
('attn', Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=drop,
norm_layer=norm_layer,
)),
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
])))
self.ffns.append(nn.Sequential(OrderedDict([
('norm', norm_layer(dim)),
('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
('mlp', Mlp(
dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)),
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
])))
@ -232,6 +347,7 @@ class VisionTransformer(nn.Module):
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
qk_norm=False,
init_values=None,
class_token=True,
no_embed_class=False,
@ -305,6 +421,7 @@ class VisionTransformer(nn.Module):
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
init_values=init_values,
drop=drop_rate,
attn_drop=attn_drop_rate,
@ -641,9 +758,8 @@ def checkpoint_filter_fn(
""" convert patch embedding weight from manual patchify + linear proj to conv"""
import re
out_dict = {}
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
state_dict = state_dict.get('model', state_dict)
state_dict = state_dict.get('state_dict', state_dict)
if 'visual.class_embedding' in state_dict:
return _convert_openai_clip(state_dict, model)
@ -1129,6 +1245,9 @@ default_cfgs = generate_default_cfgs({
url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
'vit_large_patch14_xp_224.untrained': _cfg(url=''),
'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
})
@ -1566,7 +1685,7 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock)
patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelThingsBlock)
model = _create_vision_transformer(
'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1577,7 +1696,8 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
""" ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock)
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelThingsBlock)
model = _create_vision_transformer(
'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1625,3 +1745,29 @@ def flexivit_large(pretrained=False, **kwargs):
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def vit_large_patch14_xp_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, no_embed_class=True,
norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
)
model = _create_vision_transformer(
'vit_large_patch14_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def vit_huge_patch14_xp_224(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, no_embed_class=True,
norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
)
model = _create_vision_transformer(
'vit_huge_patch14_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model

@ -25,14 +25,27 @@ _logger = logging.getLogger(__name__)
class RelPosAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_norm=False,
rel_pos_cls=None,
attn_drop=0.,
proj_drop=0.,
norm_layer=nn.LayerNorm,
):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.rel_pos = rel_pos_cls(num_heads=num_heads) if rel_pos_cls else None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
@ -40,18 +53,35 @@ class RelPosAttention(nn.Module):
def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
if self.rel_pos is not None:
attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
elif shared_rel_pos is not None:
attn = attn + shared_rel_pos
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q = self.q_norm(q)
k = self.k_norm(k)
if self.fast_attn:
if self.rel_pos is not None:
attn_bias = self.rel_pos.get_bias()
elif shared_rel_pos is not None:
attn_bias = shared_rel_pos
else:
attn_bias = None
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_bias,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if self.rel_pos is not None:
attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
elif shared_rel_pos is not None:
attn = attn + shared_rel_pos
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@ -70,18 +100,42 @@ class LayerScale(nn.Module):
class RelPosBlock(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
rel_pos_cls=None,
init_values=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = RelPosAttention(
dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop)
dim,
num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
rel_pos_cls=rel_pos_cls,
attn_drop=attn_drop,
proj_drop=drop,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -94,17 +148,41 @@ class RelPosBlock(nn.Module):
class ResPostRelPosBlock(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
rel_pos_cls=None,
init_values=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.init_values = init_values
self.attn = RelPosAttention(
dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop)
dim,
num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
rel_pos_cls=rel_pos_cls,
attn_drop=attn_drop,
proj_drop=drop,
)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -144,6 +222,7 @@ class VisionTransformerRelPos(nn.Module):
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
qk_norm=False,
init_values=1e-6,
class_token=False,
fc_norm=False,
@ -171,6 +250,7 @@ class VisionTransformerRelPos(nn.Module):
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_norm (bool): Enable normalization of query and key in attention
init_values: (float): layer-scale init values
class_token (bool): use class token (default: False)
fc_norm (bool): use pre classifier norm instead of pre-pool
@ -197,18 +277,19 @@ class VisionTransformerRelPos(nn.Module):
self.grad_checkpointing = False
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
feat_size = self.patch_embed.grid_size
rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens)
if rel_pos_type.startswith('mlp'):
if rel_pos_dim:
rel_pos_args['hidden_dim'] = rel_pos_dim
# FIXME experimenting with different relpos log coord configs
if 'swin' in rel_pos_type:
rel_pos_args['mode'] = 'swin'
elif 'rw' in rel_pos_type:
rel_pos_args['mode'] = 'rw'
rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
else:
rel_pos_cls = partial(RelPosBias, **rel_pos_args)
@ -223,9 +304,19 @@ class VisionTransformerRelPos(nn.Module):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
block_fn(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls,
init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
norm_layer=norm_layer, act_layer=act_layer)
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
rel_pos_cls=rel_pos_cls,
init_values=init_values,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
)
for i in range(depth)])
self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity()

Loading…
Cancel
Save