Update vit_relpos w/ some additional weights, some cleanup to match recent vit updates, more MLP log coord experiments.

pull/1327/head
Ross Wightman 2 years ago
parent 58621723bd
commit ce65a7b29f

@@ -8,6 +8,7 @@ import math
 import logging
 from functools import partial
 from collections import OrderedDict
+from dataclasses import dataclass
 from typing import Optional, Tuple

 import torch
@@ -16,7 +17,7 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .helpers import build_model_with_cfg, named_apply
+from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
 from .registry import register_model
@@ -47,9 +48,16 @@ default_cfgs = {
     'vit_relpos_base_patch16_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),
+    'vit_srelpos_small_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth'),
+    'vit_srelpos_medium_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth'),
+    'vit_relpos_medium_patch16_cls_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth'),
     'vit_relpos_base_patch16_cls_224': _cfg(
         url=''),
-    'vit_relpos_base_patch16_gapcls_224': _cfg(
+    'vit_relpos_base_patch16_clsgap_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'),
     'vit_relpos_small_patch16_rpn_224': _cfg(url=''),
@@ -59,35 +67,43 @@ default_cfgs = {
 }


-def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor:
-    # cut and paste w/ modifications from swin / beit codebase
-    # cls to token & token 2 cls & cls to cls
+def gen_relative_position_index(
+        q_size: Tuple[int, int],
+        k_size: Tuple[int, int] = None,
+        class_token: bool = False) -> torch.Tensor:
+    # Adapted with significant modifications from Swin / BeiT codebases
     # get pair-wise relative position index for each token inside the window
-    window_area = win_size[0] * win_size[1]
-    coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1)  # 2, Wh, Ww
-    relative_coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
-    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
-    relative_coords[:, :, 0] += win_size[0] - 1  # shift to start from 0
-    relative_coords[:, :, 1] += win_size[1] - 1
-    relative_coords[:, :, 0] *= 2 * win_size[1] - 1
+    q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1)  # 2, Wh, Ww
+    if k_size is None:
+        k_coords = q_coords
+        k_size = q_size
+    else:
+        # different q vs k sizes is a WIP
+        k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
+    relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
+    relative_coords = relative_coords.permute(1, 2, 0)  # Wh*Ww, Wh*Ww, 2
+    _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
     if class_token:
-        num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3
-        relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
-        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
+        # NOTE not intended or tested with MLP log-coords
+        max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
+        num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
+        relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
         relative_position_index[0, 0:] = num_relative_distance - 3
         relative_position_index[0:, 0] = num_relative_distance - 2
         relative_position_index[0, 0] = num_relative_distance - 1
-    else:
-        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
-    return relative_position_index
+    return relative_position_index.contiguous()


 def gen_relative_log_coords(
         win_size: Tuple[int, int],
         pretrained_win_size: Tuple[int, int] = (0, 0),
-        mode='swin'
+        mode='swin',
 ):
-    # as per official swin-v2 impl, supporting timm swin-v2-cr coords as well
+    assert mode in ('swin', 'cr', 'rw')
+    # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
     relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
     relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
     relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
@@ -100,12 +116,22 @@ def gen_relative_log_coords(
         relative_coords_table[:, :, 0] /= (win_size[0] - 1)
         relative_coords_table[:, :, 1] /= (win_size[1] - 1)
         relative_coords_table *= 8  # normalize to -8, 8
-        scale = math.log2(8)
+        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+            1.0 + relative_coords_table.abs()) / math.log2(8)
     else:
-        # FIXME we should support a form of normalization (to -1/1) for this mode?
-        scale = math.log2(math.e)
-    relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
-        1.0 + relative_coords_table.abs()) / scale
+        if mode == 'rw':
+            # cr w/ window size normalization -> [-1,1] log coords
+            relative_coords_table[:, :, 0] /= (win_size[0] - 1)
+            relative_coords_table[:, :, 1] /= (win_size[1] - 1)
+            relative_coords_table *= 8  # scale to -8, 8
+            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
+                1.0 + relative_coords_table.abs())
+            relative_coords_table /= math.log2(9)  # -> [-1, 1]
+        else:
+            # mode == 'cr'
+            relative_coords_table = torch.sign(relative_coords_table) * torch.log(
+                1.0 + relative_coords_table.abs())

     return relative_coords_table
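For reference, the new 'rw' mode pairs the 'cr' style sign-log transform with Swin-v2 style window normalization, so offsets scaled to [-8, 8] end up as log coords in [-1, 1]. A minimal standalone sketch of that transform (not part of the commit; toy offsets only):

import math
import torch

# toy window-normalized offsets already scaled to the [-8, 8] range used above
offsets = torch.linspace(-8, 8, steps=9)
log_coords = torch.sign(offsets) * torch.log2(1.0 + offsets.abs()) / math.log2(9)
print(log_coords.min().item(), log_coords.max().item())  # -1.0, 1.0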
@@ -115,19 +141,29 @@ class RelPosMlp(nn.Module):
             window_size,
             num_heads=8,
             hidden_dim=128,
-            class_token=False,
+            prefix_tokens=0,
             mode='cr',
             pretrained_window_size=(0, 0)
     ):
         super().__init__()
         self.window_size = window_size
         self.window_area = self.window_size[0] * self.window_size[1]
-        self.class_token = 1 if class_token else 0
+        self.prefix_tokens = prefix_tokens
         self.num_heads = num_heads
         self.bias_shape = (self.window_area,) * 2 + (num_heads,)
-        self.apply_sigmoid = mode == 'swin'
+        if mode == 'swin':
+            self.bias_act = nn.Sigmoid()
+            self.bias_gain = 16
+            mlp_bias = (True, False)
+        elif mode == 'rw':
+            self.bias_act = nn.Tanh()
+            self.bias_gain = 4
+            mlp_bias = True
+        else:
+            self.bias_act = nn.Identity()
+            self.bias_gain = None
+            mlp_bias = True

-        mlp_bias = (True, False) if mode == 'swin' else True
         self.mlp = Mlp(
             2,  # x, y
             hidden_features=hidden_dim,
@@ -155,10 +191,11 @@ class RelPosMlp(nn.Module):
                 self.relative_position_index.view(-1)]  # Wh*Ww,Wh*Ww,nH
         relative_position_bias = relative_position_bias.view(self.bias_shape)
         relative_position_bias = relative_position_bias.permute(2, 0, 1)
-        if self.apply_sigmoid:
-            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
-        if self.class_token:
-            relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0])
+        relative_position_bias = self.bias_act(relative_position_bias)
+        if self.bias_gain is not None:
+            relative_position_bias = self.bias_gain * relative_position_bias
+        if self.prefix_tokens:
+            relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
         return relative_position_bias.unsqueeze(0).contiguous()

     def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
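A small self-contained sketch of what the refactored get_bias tail does for the 'rw' mode (not the module itself; toy sizes): apply the mode's activation, scale by its gain, then pad a zero row/column per prefix (class) token.

import torch
import torch.nn.functional as F

num_heads, window_area, prefix_tokens = 2, 4, 1  # toy sizes
bias = torch.randn(num_heads, window_area, window_area)  # nH, Wh*Ww, Wh*Ww

# 'rw' mode: tanh activation with a gain of 4 -> values in (-4, 4)
bias = 4 * torch.tanh(bias)

# prepend a zero row/column for the prefix token, as RelPosMlp.get_bias does
bias = F.pad(bias, [prefix_tokens, 0, prefix_tokens, 0])
print(bias.shape)  # torch.Size([2, 5, 5])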
@@ -167,18 +204,18 @@ class RelPosMlp(nn.Module):
 class RelPosBias(nn.Module):

-    def __init__(self, window_size, num_heads, class_token=False):
+    def __init__(self, window_size, num_heads, prefix_tokens=0):
         super().__init__()
+        assert prefix_tokens <= 1
         self.window_size = window_size
         self.window_area = window_size[0] * window_size[1]
-        self.class_token = 1 if class_token else 0
-        self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,)
+        self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)

-        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token
+        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
         self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
         self.register_buffer(
             "relative_position_index",
-            gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
+            gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
             persistent=False,
         )
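As a worked example of the sizing above (numbers only, not code from the commit), a 14x14 token grid (224px input, patch 16) with one prefix token gives a bias of shape (197, 197, num_heads) and a table with 3 extra entries for cls-to-token, token-to-cls, and cls-to-cls:

window_size, num_heads, prefix_tokens = (14, 14), 12, 1
window_area = window_size[0] * window_size[1]                    # 196
bias_shape = (window_area + prefix_tokens,) * 2 + (num_heads,)   # (197, 197, 12)
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
print(bias_shape, num_relative_distance)                         # (197, 197, 12) 732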
@@ -306,11 +343,32 @@ class VisionTransformerRelPos(nn.Module):
     """

     def __init__(
-            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
-            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6,
-            class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None,
-            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
-            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
+            self,
+            img_size=224,
+            patch_size=16,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='avg',
+            embed_dim=768,
+            depth=12,
+            num_heads=12,
+            mlp_ratio=4.,
+            qkv_bias=True,
+            init_values=1e-6,
+            class_token=False,
+            fc_norm=False,
+            rel_pos_type='mlp',
+            rel_pos_dim=None,
+            shared_rel_pos=False,
+            drop_rate=0.,
+            attn_drop_rate=0.,
+            drop_path_rate=0.,
+            weight_init='skip',
+            embed_layer=PatchEmbed,
+            norm_layer=None,
+            act_layer=None,
+            block_fn=RelPosBlock
+    ):
         """
         Args:
             img_size (int, tuple): input image size
@@ -345,19 +403,22 @@ class VisionTransformerRelPos(nn.Module):
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
-        self.num_tokens = 1 if class_token else 0
+        self.num_prefix_tokens = 1 if class_token else 0
         self.grad_checkpointing = False

         self.patch_embed = embed_layer(
             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
         feat_size = self.patch_embed.grid_size

-        rel_pos_args = dict(window_size=feat_size, class_token=class_token)
+        rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens)
         if rel_pos_type.startswith('mlp'):
             if rel_pos_dim:
                 rel_pos_args['hidden_dim'] = rel_pos_dim
+            # FIXME experimenting with different relpos log coord configs
             if 'swin' in rel_pos_type:
                 rel_pos_args['mode'] = 'swin'
+            elif 'rw' in rel_pos_type:
+                rel_pos_args['mode'] = 'rw'
             rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
         else:
             rel_pos_cls = partial(RelPosBias, **rel_pos_args)
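With the added branch, the rel_pos_type string selects the MLP log-coord mode by substring: a hypothetical rel_pos_type='mlp_rw' would pick 'rw', 'mlp_swin' picks 'swin', and plain 'mlp' keeps RelPosMlp's default 'cr'. A tiny sketch of the same dispatch in isolation (pick_mode is illustrative, not part of the commit):

def pick_mode(rel_pos_type: str) -> str:
    # mirrors the substring dispatch in VisionTransformerRelPos.__init__
    if 'swin' in rel_pos_type:
        return 'swin'
    elif 'rw' in rel_pos_type:
        return 'rw'
    return 'cr'  # RelPosMlp default mode

print(pick_mode('mlp'), pick_mode('mlp_swin'), pick_mode('mlp_rw'))  # cr swin rw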
@@ -367,7 +428,7 @@ class VisionTransformerRelPos(nn.Module):
             # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both...
             rel_pos_cls = None

-        self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None
+        self.cls_token = nn.Parameter(torch.zeros(1, self.num_prefix_tokens, embed_dim)) if class_token else None

         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList([
@@ -434,7 +495,7 @@ class VisionTransformerRelPos(nn.Module):
     def forward_head(self, x, pre_logits: bool = False):
         if self.global_pool:
-            x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
         x = self.fc_norm(x)
         return x if pre_logits else self.head(x)
@@ -502,6 +563,41 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_srelpos_small_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Small (ViT-S/16) w/ shared relative log-coord position, no class token
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=False,
+        rel_pos_dim=384, shared_rel_pos=True, **kwargs)
+    model = _create_vision_transformer_relpos('vit_srelpos_small_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_srelpos_medium_patch16_224(pretrained=False, **kwargs):
+    """ ViT-Medium (ViT-M/16) w/ shared relative log-coord position, no class token
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
+        rel_pos_dim=512, shared_rel_pos=True, **kwargs)
+    model = _create_vision_transformer_relpos(
+        'vit_srelpos_medium_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_relpos_medium_patch16_cls_224(pretrained=False, **kwargs):
+    """ ViT-Medium (ViT-M/16) w/ relative log-coord position, class token present
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
+        rel_pos_dim=256, class_token=True, global_pool='token', **kwargs)
+    model = _create_vision_transformer_relpos(
+        'vit_relpos_medium_patch16_cls_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
 def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
@@ -514,14 +610,14 @@ def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
 @register_model
-def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs):
+def vit_relpos_base_patch16_clsgap_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
     NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled
     Leaving here for comparisons w/ a future re-train as it performs quite well.
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs)
-    model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer_relpos('vit_relpos_base_patch16_clsgap_224', pretrained=pretrained, **model_kwargs)
     return model
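The newly registered srelpos / medium variants are selectable by name through the timm factory. A hedged usage sketch, assuming a timm build that includes this commit and that the release weights listed in default_cfgs are reachable:

import timm
import torch

# 'vit_srelpos_medium_patch16_224' is one of the variants registered in this commit
model = timm.create_model('vit_srelpos_medium_patch16_224', pretrained=True)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])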
