Merge pull request #1593 from rwightman/multi-weight_effnet_convnext
Update efficientnet.py and convnext.py to multi-weight, add new 12k pretrained weightspull/1612/head
commit
4e24f75289
@ -1,207 +1,52 @@
|
||||
""" Position Embedding Utilities
|
||||
|
||||
Hacked together by / Copyright 2022 Ross Wightman
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from typing import List, Tuple, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
|
||||
def pixel_freq_bands(
|
||||
num_bands: int,
|
||||
max_freq: float = 224.,
|
||||
linear_bands: bool = True,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
if linear_bands:
|
||||
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
|
||||
else:
|
||||
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
|
||||
return bands * torch.pi
|
||||
|
||||
|
||||
def inv_freq_bands(
|
||||
num_bands: int,
|
||||
temperature: float = 100000.,
|
||||
step: int = 2,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
|
||||
return inv_freq
|
||||
|
||||
|
||||
def build_sincos2d_pos_embed(
|
||||
feat_shape: List[int],
|
||||
dim: int = 64,
|
||||
temperature: float = 10000.,
|
||||
reverse_coord: bool = False,
|
||||
interleave_sin_cos: bool = False,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
feat_shape:
|
||||
dim:
|
||||
temperature:
|
||||
reverse_coord: stack grid order W, H instead of H, W
|
||||
interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
|
||||
dtype:
|
||||
device:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
|
||||
pos_dim = dim // 4
|
||||
bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
|
||||
|
||||
if reverse_coord:
|
||||
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
|
||||
grid = torch.stack(
|
||||
torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
|
||||
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
|
||||
# FIXME add support for unflattened spatial dim?
|
||||
|
||||
stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
|
||||
pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
|
||||
return pos_emb
|
||||
|
||||
|
||||
def build_fourier_pos_embed(
|
||||
feat_shape: List[int],
|
||||
bands: Optional[torch.Tensor] = None,
|
||||
num_bands: int = 64,
|
||||
max_res: int = 224,
|
||||
linear_bands: bool = False,
|
||||
include_grid: bool = False,
|
||||
concat_out: bool = True,
|
||||
in_pixels: bool = True,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
if bands is None:
|
||||
if in_pixels:
|
||||
bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
|
||||
else:
|
||||
bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
|
||||
else:
|
||||
if device is None:
|
||||
device = bands.device
|
||||
if dtype is None:
|
||||
dtype = bands.dtype
|
||||
|
||||
if in_pixels:
|
||||
grid = torch.stack(torch.meshgrid(
|
||||
[torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
|
||||
else:
|
||||
grid = torch.stack(torch.meshgrid(
|
||||
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
|
||||
grid = grid.unsqueeze(-1)
|
||||
pos = grid * bands
|
||||
import torch.nn.functional as F
|
||||
|
||||
pos_sin, pos_cos = pos.sin(), pos.cos()
|
||||
out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
|
||||
# FIXME torchscript doesn't like multiple return types, probably need to always cat?
|
||||
if concat_out:
|
||||
out = torch.cat(out, dim=-1)
|
||||
return out
|
||||
from .helpers import to_2tuple
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
class FourierEmbed(nn.Module):
|
||||
|
||||
def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
|
||||
super().__init__()
|
||||
self.max_res = max_res
|
||||
self.num_bands = num_bands
|
||||
self.concat_grid = concat_grid
|
||||
self.keep_spatial = keep_spatial
|
||||
self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
B, C = x.shape[:2]
|
||||
feat_shape = x.shape[2:]
|
||||
emb = build_fourier_pos_embed(
|
||||
feat_shape,
|
||||
self.bands,
|
||||
include_grid=self.concat_grid,
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
emb = emb.transpose(-1, -2).flatten(len(feat_shape))
|
||||
batch_expand = (B,) + (-1,) * (x.ndim - 1)
|
||||
|
||||
# FIXME support nD
|
||||
if self.keep_spatial:
|
||||
x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
|
||||
else:
|
||||
x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
|
||||
x = x.reshape(B, feat_shape.numel(), -1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def rot(x):
|
||||
return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
|
||||
|
||||
|
||||
def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
|
||||
return x * cos_emb + rot(x) * sin_emb
|
||||
|
||||
|
||||
def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = [x]
|
||||
return [t * cos_emb + rot(t) * sin_emb for t in x]
|
||||
|
||||
|
||||
def apply_rot_embed_split(x: torch.Tensor, emb):
|
||||
split = emb.shape[-1] // 2
|
||||
return x * emb[:, :split] + rot(x) * emb[:, split:]
|
||||
|
||||
|
||||
def build_rotary_pos_embed(
|
||||
feat_shape: List[int],
|
||||
bands: Optional[torch.Tensor] = None,
|
||||
dim: int = 64,
|
||||
max_freq: float = 224,
|
||||
linear_bands: bool = False,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
def resample_abs_pos_embed(
|
||||
posemb,
|
||||
new_size: List[int],
|
||||
old_size: Optional[List[int]] = None,
|
||||
num_prefix_tokens: int = 1,
|
||||
interpolation: str = 'bicubic',
|
||||
antialias: bool = True,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
NOTE: shape arg should include spatial dim only
|
||||
"""
|
||||
feat_shape = torch.Size(feat_shape)
|
||||
|
||||
sin_emb, cos_emb = build_fourier_pos_embed(
|
||||
feat_shape, bands=bands, num_bands=dim // 4, max_res=max_freq, linear_bands=linear_bands,
|
||||
concat_out=False, device=device, dtype=dtype)
|
||||
N = feat_shape.numel()
|
||||
sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
|
||||
cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
|
||||
return sin_emb, cos_emb
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
""" Rotary position embedding
|
||||
|
||||
NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
|
||||
been well tested, and will likely change. It will be moved to its own file.
|
||||
# sort out sizes, assume square if old size not provided
|
||||
new_size = to_2tuple(new_size)
|
||||
new_ntok = new_size[0] * new_size[1]
|
||||
if not old_size:
|
||||
old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
|
||||
old_size = to_2tuple(old_size)
|
||||
if new_size == old_size: # might not both be same container type
|
||||
return posemb
|
||||
|
||||
if num_prefix_tokens:
|
||||
posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
|
||||
else:
|
||||
posemb_prefix, posemb = None, posemb
|
||||
|
||||
The following impl/resources were referenced for this impl:
|
||||
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
|
||||
* https://blog.eleuther.ai/rotary-embeddings/
|
||||
"""
|
||||
def __init__(self, dim, max_res=224, linear_bands: bool = False):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)
|
||||
# do the interpolation
|
||||
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
|
||||
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
|
||||
posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)
|
||||
|
||||
def get_embed(self, shape: List[int]):
|
||||
return build_rotary_pos_embed(shape, self.bands)
|
||||
if verbose:
|
||||
_logger.info(f'Resized position embedding: {old_size} to {new_size}.')
|
||||
|
||||
def forward(self, x):
|
||||
# assuming channel-first tensor where spatial dim are >= 2
|
||||
sin_emb, cos_emb = self.get_embed(x.shape[2:])
|
||||
return apply_rot_embed(x, sin_emb, cos_emb)
|
||||
# add back extra (class, etc) prefix tokens
|
||||
if posemb_prefix is not None:
|
||||
print(posemb_prefix.shape, posemb.shape)
|
||||
posemb = torch.cat([posemb_prefix, posemb], dim=1)
|
||||
return posemb
|
||||
|
@ -0,0 +1,283 @@
|
||||
""" Relative position embedding modules and functions
|
||||
|
||||
Hacked together by / Copyright 2022 Ross Wightman
|
||||
"""
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .mlp import Mlp
|
||||
from .weight_init import trunc_normal_
|
||||
|
||||
|
||||
def gen_relative_position_index(
|
||||
q_size: Tuple[int, int],
|
||||
k_size: Tuple[int, int] = None,
|
||||
class_token: bool = False) -> torch.Tensor:
|
||||
# Adapted with significant modifications from Swin / BeiT codebases
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww
|
||||
if k_size is None:
|
||||
k_coords = q_coords
|
||||
k_size = q_size
|
||||
else:
|
||||
# different q vs k sizes is a WIP
|
||||
k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
|
||||
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2
|
||||
_, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
|
||||
|
||||
if class_token:
|
||||
# handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
|
||||
# NOTE not intended or tested with MLP log-coords
|
||||
max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
|
||||
num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
|
||||
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
|
||||
relative_position_index[0, 0:] = num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = num_relative_distance - 2
|
||||
relative_position_index[0, 0] = num_relative_distance - 1
|
||||
|
||||
return relative_position_index.contiguous()
|
||||
|
||||
|
||||
class RelPosBias(nn.Module):
|
||||
""" Relative Position Bias
|
||||
Adapted from Swin-V1 relative position bias impl, modularized.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size, num_heads, prefix_tokens=0):
|
||||
super().__init__()
|
||||
assert prefix_tokens <= 1
|
||||
self.window_size = window_size
|
||||
self.window_area = window_size[0] * window_size[1]
|
||||
self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
|
||||
|
||||
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
|
||||
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
|
||||
self.register_buffer(
|
||||
"relative_position_index",
|
||||
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
def get_bias(self) -> torch.Tensor:
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
|
||||
# win_h * win_w, win_h * win_w, num_heads
|
||||
relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
|
||||
return relative_position_bias.unsqueeze(0).contiguous()
|
||||
|
||||
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
|
||||
return attn + self.get_bias()
|
||||
|
||||
|
||||
def gen_relative_log_coords(
|
||||
win_size: Tuple[int, int],
|
||||
pretrained_win_size: Tuple[int, int] = (0, 0),
|
||||
mode='swin',
|
||||
):
|
||||
assert mode in ('swin', 'cr', 'rw')
|
||||
# as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
|
||||
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
|
||||
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
|
||||
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
|
||||
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
|
||||
if mode == 'swin':
|
||||
if pretrained_win_size[0] > 0:
|
||||
relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
|
||||
relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
|
||||
else:
|
||||
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
|
||||
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
|
||||
relative_coords_table *= 8 # normalize to -8, 8
|
||||
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
|
||||
1.0 + relative_coords_table.abs()) / math.log2(8)
|
||||
else:
|
||||
if mode == 'rw':
|
||||
# cr w/ window size normalization -> [-1,1] log coords
|
||||
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
|
||||
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
|
||||
relative_coords_table *= 8 # scale to -8, 8
|
||||
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
|
||||
1.0 + relative_coords_table.abs())
|
||||
relative_coords_table /= math.log2(9) # -> [-1, 1]
|
||||
else:
|
||||
# mode == 'cr'
|
||||
relative_coords_table = torch.sign(relative_coords_table) * torch.log(
|
||||
1.0 + relative_coords_table.abs())
|
||||
|
||||
return relative_coords_table
|
||||
|
||||
|
||||
class RelPosMlp(nn.Module):
|
||||
""" Log-Coordinate Relative Position MLP
|
||||
Based on ideas presented in Swin-V2 paper (https://arxiv.org/abs/2111.09883)
|
||||
|
||||
This impl covers the 'swin' implementation as well as two timm specific modes ('cr', and 'rw')
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
window_size,
|
||||
num_heads=8,
|
||||
hidden_dim=128,
|
||||
prefix_tokens=0,
|
||||
mode='cr',
|
||||
pretrained_window_size=(0, 0)
|
||||
):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.window_area = self.window_size[0] * self.window_size[1]
|
||||
self.prefix_tokens = prefix_tokens
|
||||
self.num_heads = num_heads
|
||||
self.bias_shape = (self.window_area,) * 2 + (num_heads,)
|
||||
if mode == 'swin':
|
||||
self.bias_act = nn.Sigmoid()
|
||||
self.bias_gain = 16
|
||||
mlp_bias = (True, False)
|
||||
elif mode == 'rw':
|
||||
self.bias_act = nn.Tanh()
|
||||
self.bias_gain = 4
|
||||
mlp_bias = True
|
||||
else:
|
||||
self.bias_act = nn.Identity()
|
||||
self.bias_gain = None
|
||||
mlp_bias = True
|
||||
|
||||
self.mlp = Mlp(
|
||||
2, # x, y
|
||||
hidden_features=hidden_dim,
|
||||
out_features=num_heads,
|
||||
act_layer=nn.ReLU,
|
||||
bias=mlp_bias,
|
||||
drop=(0.125, 0.)
|
||||
)
|
||||
|
||||
self.register_buffer(
|
||||
"relative_position_index",
|
||||
gen_relative_position_index(window_size),
|
||||
persistent=False)
|
||||
|
||||
# get relative_coords_table
|
||||
self.register_buffer(
|
||||
"rel_coords_log",
|
||||
gen_relative_log_coords(window_size, pretrained_window_size, mode=mode),
|
||||
persistent=False)
|
||||
|
||||
def get_bias(self) -> torch.Tensor:
|
||||
relative_position_bias = self.mlp(self.rel_coords_log)
|
||||
if self.relative_position_index is not None:
|
||||
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
|
||||
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.view(self.bias_shape)
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1)
|
||||
relative_position_bias = self.bias_act(relative_position_bias)
|
||||
if self.bias_gain is not None:
|
||||
relative_position_bias = self.bias_gain * relative_position_bias
|
||||
if self.prefix_tokens:
|
||||
relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
|
||||
return relative_position_bias.unsqueeze(0).contiguous()
|
||||
|
||||
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
|
||||
return attn + self.get_bias()
|
||||
|
||||
|
||||
def generate_lookup_tensor(
|
||||
length: int,
|
||||
max_relative_position: Optional[int] = None,
|
||||
):
|
||||
"""Generate a one_hot lookup tensor to reindex embeddings along one dimension.
|
||||
|
||||
Args:
|
||||
length: the length to reindex to.
|
||||
max_relative_position: the maximum relative position to consider.
|
||||
Relative position embeddings for distances above this threshold
|
||||
are zeroed out.
|
||||
Returns:
|
||||
a lookup Tensor of size [length, length, vocab_size] that satisfies
|
||||
ret[n,m,v] = 1{m - n + max_relative_position = v}.
|
||||
"""
|
||||
if max_relative_position is None:
|
||||
max_relative_position = length - 1
|
||||
# Return the cached lookup tensor, otherwise compute it and cache it.
|
||||
vocab_size = 2 * max_relative_position + 1
|
||||
ret = torch.zeros(length, length, vocab_size)
|
||||
for i in range(length):
|
||||
for x in range(length):
|
||||
v = x - i + max_relative_position
|
||||
if abs(x - i) > max_relative_position:
|
||||
continue
|
||||
ret[i, x, v] = 1
|
||||
return ret
|
||||
|
||||
|
||||
def reindex_2d_einsum_lookup(
|
||||
relative_position_tensor,
|
||||
height: int,
|
||||
width: int,
|
||||
height_lookup: torch.Tensor,
|
||||
width_lookup: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Reindex 2d relative position bias with 2 independent einsum lookups.
|
||||
|
||||
Adapted from:
|
||||
https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
|
||||
|
||||
Args:
|
||||
relative_position_tensor: tensor of shape
|
||||
[..., vocab_height, vocab_width, ...].
|
||||
height: height to reindex to.
|
||||
width: width to reindex to.
|
||||
height_lookup: one-hot height lookup
|
||||
width_lookup: one-hot width lookup
|
||||
Returns:
|
||||
reindexed_tensor: a Tensor of shape
|
||||
[..., height * width, height * width, ...]
|
||||
"""
|
||||
reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup)
|
||||
reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup)
|
||||
area = height * width
|
||||
return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area)
|
||||
|
||||
|
||||
class RelPosBiasTf(nn.Module):
|
||||
""" Relative Position Bias Impl (Compatible with Tensorflow MaxViT models)
|
||||
Adapted from:
|
||||
https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
|
||||
"""
|
||||
def __init__(self, window_size, num_heads, prefix_tokens=0):
|
||||
super().__init__()
|
||||
assert prefix_tokens <= 1
|
||||
self.window_size = window_size
|
||||
self.window_area = window_size[0] * window_size[1]
|
||||
self.num_heads = num_heads
|
||||
|
||||
vocab_height = 2 * window_size[0] - 1
|
||||
vocab_width = 2 * window_size[1] - 1
|
||||
self.bias_shape = (self.num_heads, vocab_height, vocab_width)
|
||||
self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape))
|
||||
self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False)
|
||||
self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False)
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
nn.init.normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
def get_bias(self) -> torch.Tensor:
|
||||
# FIXME change to not use one-hot/einsum?
|
||||
return reindex_2d_einsum_lookup(
|
||||
self.relative_position_bias_table,
|
||||
self.window_size[0],
|
||||
self.window_size[1],
|
||||
self.height_lookup,
|
||||
self.width_lookup
|
||||
)
|
||||
|
||||
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
|
||||
return attn + self.get_bias()
|
@ -0,0 +1,219 @@
|
||||
""" Sin-cos, fourier, rotary position embedding modules and functions
|
||||
|
||||
Hacked together by / Copyright 2022 Ross Wightman
|
||||
"""
|
||||
import math
|
||||
from typing import List, Tuple, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
|
||||
|
||||
def pixel_freq_bands(
|
||||
num_bands: int,
|
||||
max_freq: float = 224.,
|
||||
linear_bands: bool = True,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
if linear_bands:
|
||||
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
|
||||
else:
|
||||
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
|
||||
return bands * torch.pi
|
||||
|
||||
|
||||
def inv_freq_bands(
|
||||
num_bands: int,
|
||||
temperature: float = 100000.,
|
||||
step: int = 2,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
|
||||
return inv_freq
|
||||
|
||||
|
||||
def build_sincos2d_pos_embed(
|
||||
feat_shape: List[int],
|
||||
dim: int = 64,
|
||||
temperature: float = 10000.,
|
||||
reverse_coord: bool = False,
|
||||
interleave_sin_cos: bool = False,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
feat_shape:
|
||||
dim:
|
||||
temperature:
|
||||
reverse_coord: stack grid order W, H instead of H, W
|
||||
interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
|
||||
dtype:
|
||||
device:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
|
||||
pos_dim = dim // 4
|
||||
bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
|
||||
|
||||
if reverse_coord:
|
||||
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
|
||||
grid = torch.stack(
|
||||
torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
|
||||
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
|
||||
# FIXME add support for unflattened spatial dim?
|
||||
|
||||
stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
|
||||
pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
|
||||
return pos_emb
|
||||
|
||||
|
||||
def build_fourier_pos_embed(
|
||||
feat_shape: List[int],
|
||||
bands: Optional[torch.Tensor] = None,
|
||||
num_bands: int = 64,
|
||||
max_res: int = 224,
|
||||
linear_bands: bool = False,
|
||||
include_grid: bool = False,
|
||||
concat_out: bool = True,
|
||||
in_pixels: bool = True,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
if bands is None:
|
||||
if in_pixels:
|
||||
bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
|
||||
else:
|
||||
bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
|
||||
else:
|
||||
if device is None:
|
||||
device = bands.device
|
||||
if dtype is None:
|
||||
dtype = bands.dtype
|
||||
|
||||
if in_pixels:
|
||||
grid = torch.stack(torch.meshgrid(
|
||||
[torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
|
||||
else:
|
||||
grid = torch.stack(torch.meshgrid(
|
||||
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
|
||||
grid = grid.unsqueeze(-1)
|
||||
pos = grid * bands
|
||||
|
||||
pos_sin, pos_cos = pos.sin(), pos.cos()
|
||||
out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
|
||||
# FIXME torchscript doesn't like multiple return types, probably need to always cat?
|
||||
if concat_out:
|
||||
out = torch.cat(out, dim=-1)
|
||||
return out
|
||||
|
||||
|
||||
class FourierEmbed(nn.Module):
|
||||
|
||||
def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
|
||||
super().__init__()
|
||||
self.max_res = max_res
|
||||
self.num_bands = num_bands
|
||||
self.concat_grid = concat_grid
|
||||
self.keep_spatial = keep_spatial
|
||||
self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
B, C = x.shape[:2]
|
||||
feat_shape = x.shape[2:]
|
||||
emb = build_fourier_pos_embed(
|
||||
feat_shape,
|
||||
self.bands,
|
||||
include_grid=self.concat_grid,
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
emb = emb.transpose(-1, -2).flatten(len(feat_shape))
|
||||
batch_expand = (B,) + (-1,) * (x.ndim - 1)
|
||||
|
||||
# FIXME support nD
|
||||
if self.keep_spatial:
|
||||
x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
|
||||
else:
|
||||
x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
|
||||
x = x.reshape(B, feat_shape.numel(), -1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def rot(x):
|
||||
return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
|
||||
|
||||
|
||||
def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
|
||||
return x * cos_emb + rot(x) * sin_emb
|
||||
|
||||
|
||||
def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = [x]
|
||||
return [t * cos_emb + rot(t) * sin_emb for t in x]
|
||||
|
||||
|
||||
def apply_rot_embed_split(x: torch.Tensor, emb):
|
||||
split = emb.shape[-1] // 2
|
||||
return x * emb[:, :split] + rot(x) * emb[:, split:]
|
||||
|
||||
|
||||
def build_rotary_pos_embed(
|
||||
feat_shape: List[int],
|
||||
bands: Optional[torch.Tensor] = None,
|
||||
dim: int = 64,
|
||||
max_freq: float = 224,
|
||||
linear_bands: bool = False,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
"""
|
||||
NOTE: shape arg should include spatial dim only
|
||||
"""
|
||||
feat_shape = torch.Size(feat_shape)
|
||||
|
||||
sin_emb, cos_emb = build_fourier_pos_embed(
|
||||
feat_shape,
|
||||
bands=bands,
|
||||
num_bands=dim // 4,
|
||||
max_res=max_freq,
|
||||
linear_bands=linear_bands,
|
||||
concat_out=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
N = feat_shape.numel()
|
||||
sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
|
||||
cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
|
||||
return sin_emb, cos_emb
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
""" Rotary position embedding
|
||||
|
||||
NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
|
||||
been well tested, and will likely change. It will be moved to its own file.
|
||||
|
||||
The following impl/resources were referenced for this impl:
|
||||
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
|
||||
* https://blog.eleuther.ai/rotary-embeddings/
|
||||
"""
|
||||
|
||||
def __init__(self, dim, max_res=224, linear_bands: bool = False):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)
|
||||
|
||||
def get_embed(self, shape: List[int]):
|
||||
return build_rotary_pos_embed(shape, self.bands)
|
||||
|
||||
def forward(self, x):
|
||||
# assuming channel-first tensor where spatial dim are >= 2
|
||||
sin_emb, cos_emb = self.get_embed(x.shape[2:])
|
||||
return apply_rot_embed(x, sin_emb, cos_emb)
|
File diff suppressed because it is too large
Load Diff
@ -1 +1 @@
|
||||
__version__ = '0.8.1dev0'
|
||||
__version__ = '0.8.2dev0'
|
||||
|
Loading…
Reference in new issue