Remove persistent buffers from Swin-V2. Change SwinV2Cr cos attn + tau/logit_scale to match official, add ckpt convert, init_value zeros resid LN weight by default

pull/1259/head
Ross Wightman 3 years ago
parent 27c42f0830
commit d4c0588012

@@ -25,7 +25,6 @@ from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
 from .registry import register_model
-from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit
 
 
 def _cfg(url='', **kwargs):
@@ -75,7 +74,7 @@ default_cfgs = {
     ),
     'swinv2_base_window12to24_192to384_22kft1k': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
-        input_size=(3, 384, 384)
+        input_size=(3, 384, 384), crop_pct=1.0,
     ),
     'swinv2_large_window12_192_22k': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
@@ -87,7 +86,7 @@ default_cfgs = {
     ),
     'swinv2_large_window12to24_192to384_22kft1k': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
-        input_size=(3, 384, 384)
+        input_size=(3, 384, 384), crop_pct=1.0,
     ),
 }
@@ -174,7 +173,7 @@ class WindowAttention(nn.Module):
         relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
             torch.abs(relative_coords_table) + 1.0) / math.log2(8)
-        self.register_buffer("relative_coords_table", relative_coords_table)
+        self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
 
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size[0])
@@ -187,7 +186,7 @@ class WindowAttention(nn.Module):
         relative_coords[:, :, 1] += self.window_size[1] - 1
         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
-        self.register_buffer("relative_position_index", relative_position_index)
+        self.register_buffer("relative_position_index", relative_position_index, persistent=False)
 
         self.qkv = nn.Linear(dim, dim * 3, bias=False)
         if qkv_bias:
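Note (not part of the diff): a minimal sketch of what `persistent=False` changes. Non-persistent buffers are still built in `__init__` and moved with the module, but they are excluded from `state_dict()`, which is why older checkpoints containing these keys need the filtering added further down. The module and buffer names here are made up for illustration.

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # saved into state_dict (previous behaviour)
        self.register_buffer("saved_index", torch.arange(4))
        # rebuilt in __init__, excluded from state_dict (behaviour after this commit)
        self.register_buffer("transient_index", torch.arange(4), persistent=False)

print(list(Toy().state_dict().keys()))  # ['saved_index']
```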
@@ -215,7 +214,7 @@ class WindowAttention(nn.Module):
             qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)
 
         # cosine attention
         attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
@@ -559,9 +558,6 @@ class SwinTransformerV2(nn.Module):
             trunc_normal_(m.weight, std=.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
 
     @torch.jit.ignore
     def no_weight_decay(self):
@@ -621,6 +617,18 @@ class SwinTransformerV2(nn.Module):
         return x
 
 
+def checkpoint_filter_fn(state_dict, model):
+    out_dict = {}
+    if 'model' in state_dict:
+        # For deit models
+        state_dict = state_dict['model']
+    for k, v in state_dict.items():
+        if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
+            continue  # skip buffers that should not be persistent
+        out_dict[k] = v
+    return out_dict
+
+
 def _create_swin_transformer_v2(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         SwinTransformerV2, variant, pretrained,
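A hedged illustration (not from the commit) of what the `checkpoint_filter_fn` added in the hunk above does when loading an older checkpoint: keys for the now non-persistent buffers are dropped, everything else passes through unchanged. The key names and shapes below are invented for the example.

```python
import torch

old_state_dict = {
    'layers.0.blocks.0.attn.relative_position_index': torch.zeros(49, 49, dtype=torch.long),
    'layers.0.blocks.0.attn.relative_coords_table': torch.zeros(1, 13, 13, 2),
    'layers.0.blocks.0.attn.qkv.weight': torch.zeros(96 * 3, 96),
}
filtered = checkpoint_filter_fn(old_state_dict, model=None)  # model arg is unused by the filter
print(sorted(filtered))  # only the qkv.weight key survives
```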

@@ -34,6 +34,7 @@ from typing import Tuple, Optional, List, Union, Any, Type
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -41,7 +42,7 @@ from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
 from .layers import DropPath, Mlp, to_2tuple, _assert
 from .registry import register_model
-from .vision_transformer import checkpoint_filter_fn
 
 _logger = logging.getLogger(__name__)
@@ -186,12 +187,13 @@ class WindowMultiHeadAttention(nn.Module):
             act_layer=nn.ReLU,
             drop=(0.125, 0.)  # FIXME should there be stochasticity, appears to 'overfit' without?
         )
-        self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
+        # NOTE old checkpoints used inverse of logit_scale ('tau') following the paper, see conversion fn
+        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads)))
         self._make_pair_wise_relative_positions()
 
     def _make_pair_wise_relative_positions(self) -> None:
         """Method initializes the pair-wise relative positions to compute the positional biases."""
-        device = self.tau.device
+        device = self.logit_scale.device
         coordinates = torch.stack(torch.meshgrid([
             torch.arange(self.window_size[0], device=device),
             torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1)
@@ -250,10 +252,11 @@ class WindowMultiHeadAttention(nn.Module):
         query, key, value = qkv.unbind(0)
 
         # compute attention map with scaled cosine attention
-        denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1)
-        attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6)
-        attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1)
+        attn = (F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1))
+        logit_scale = torch.clamp(self.logit_scale.reshape(1, self.num_heads, 1, 1), max=math.log(1. / 0.01)).exp()
+        attn = attn * logit_scale
         attn = attn + self._relative_positional_encodings()
         if mask is not None:
             # Apply mask if utilized
             num_win: int = mask.shape[0]
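A quick sanity sketch (not part of the diff; shapes are illustrative and `tau` is assumed to be above the 0.01 floor) showing that the new `F.normalize` + `logit_scale` form matches the old explicit-norm, divide-by-`tau` form, given that converted checkpoints store `logit_scale = log(1 / tau)`:

```python
import math
import torch
import torch.nn.functional as F

num_heads = 4
query = torch.randn(2, num_heads, 49, 32)
key = torch.randn(2, num_heads, 49, 32)
tau = torch.rand(num_heads) + 0.5  # old per-head temperature, safely above the 0.01 floor

# old: cosine similarity via explicit norms, divided by clamped tau
denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1)
attn_old = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6)
attn_old = attn_old / tau.clamp(min=0.01).reshape(1, num_heads, 1, 1)

# new: normalized q/k, multiplied by exp(logit_scale) clamped at 1 / 0.01 = 100
logit_scale = torch.log(1 / tau)  # what the checkpoint conversion below produces
scale = torch.clamp(logit_scale.reshape(1, num_heads, 1, 1), max=math.log(1. / 0.01)).exp()
attn_new = (F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1)) * scale

print(torch.allclose(attn_old, attn_new, atol=1e-5))  # True
```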
@@ -309,7 +312,7 @@ class SwinTransformerBlock(nn.Module):
             window_size: Tuple[int, int],
             shift_size: Tuple[int, int] = (0, 0),
             mlp_ratio: float = 4.0,
-            init_values: float = 0,
+            init_values: Optional[float] = 0,
             drop: float = 0.0,
             drop_attn: float = 0.0,
             drop_path: float = 0.0,
@@ -323,7 +326,7 @@ class SwinTransformerBlock(nn.Module):
         self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
         self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
         self.window_area = self.window_size[0] * self.window_size[1]
-        self.init_values: float = init_values
+        self.init_values: Optional[float] = init_values
 
         # attn branch
         self.attn = WindowMultiHeadAttention(
@@ -387,7 +390,7 @@ class SwinTransformerBlock(nn.Module):
     def init_weights(self):
         # extra, module specific weight init
-        if self.init_values:
+        if self.init_values is not None:
             nn.init.constant_(self.norm1.weight, self.init_values)
             nn.init.constant_(self.norm2.weight, self.init_values)
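Context note (not part of the diff) on the `is not None` check: with the default `init_values = 0.`, the two residual-branch LayerNorm weights are now zero-initialized, so each post-norm residual branch contributes nothing at init; previously `0` was falsy, the init was skipped, and the default LayerNorm weight of 1 was kept. Passing `init_values=None` restores the old behaviour. A toy check with a made-up embed dim:

```python
import torch.nn as nn

norm1 = nn.LayerNorm(96)            # hypothetical embed dim
init_values = 0.                    # new default
if init_values is not None:         # the new check: 0.0 no longer skips the init
    nn.init.constant_(norm1.weight, init_values)
print(float(norm1.weight.abs().sum()))  # 0.0 -> the residual branch is silenced at init
```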
@@ -536,7 +539,7 @@ class SwinTransformerStage(nn.Module):
             feat_size: Tuple[int, int],
             window_size: Tuple[int, int],
             mlp_ratio: float = 4.0,
-            init_values: float = 0.0,
+            init_values: Optional[float] = 0.0,
             drop: float = 0.0,
             drop_attn: float = 0.0,
             drop_path: Union[List[float], float] = 0.0,
@@ -650,7 +653,7 @@ class SwinTransformerV2Cr(nn.Module):
             depths: Tuple[int, ...] = (2, 2, 6, 2),
             num_heads: Tuple[int, ...] = (3, 6, 12, 24),
             mlp_ratio: float = 4.0,
-            init_values: float = 0.0,
+            init_values: Optional[float] = 0.,
             drop_rate: float = 0.0,
             attn_drop_rate: float = 0.0,
             drop_path_rate: float = 0.0,
@@ -808,6 +811,21 @@ def init_weights(module: nn.Module, name: str = ''):
         module.init_weights()
 
 
+def checkpoint_filter_fn(state_dict, model):
+    """ convert patch embedding weight from manual patchify + linear proj to conv"""
+    out_dict = {}
+    if 'model' in state_dict:
+        # For deit models
+        state_dict = state_dict['model']
+    for k, v in state_dict.items():
+        if 'tau' in k:
+            # convert old tau based checkpoints -> logit_scale (inverse)
+            v = torch.log(1 / v)
+            k = k.replace('tau', 'logit_scale')
+        out_dict[k] = v
+    return out_dict
+
+
 def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
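A small round-trip check (illustrative key name, not from the commit) of the conversion function added in the hunk above: an old `tau` entry is renamed to `logit_scale` and inverted in log space, so the original temperature is recoverable as `1 / exp(logit_scale)`:

```python
import torch

old_sd = {'stages.0.blocks.0.attn.tau': torch.tensor([0.5, 1.0, 2.0])}
new_sd = checkpoint_filter_fn(old_sd, model=None)  # model arg is unused by the filter

key = 'stages.0.blocks.0.attn.logit_scale'
print(key in new_sd)  # True
print(torch.allclose(1 / new_sd[key].exp(), old_sd['stages.0.blocks.0.attn.tau']))  # True
```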
@@ -890,7 +908,6 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs):
         embed_dim=96,
         depths=(2, 2, 18, 2),
         num_heads=(3, 6, 12, 24),
-        init_values=1e-5,
         extra_norm_stage=True,
         **kwargs
     )
@@ -928,7 +945,6 @@ def swinv2_cr_base_ns_224(pretrained=False, **kwargs):
         embed_dim=128,
         depths=(2, 2, 18, 2),
         num_heads=(4, 8, 16, 32),
-        init_values=1e-6,
         extra_norm_stage=True,
         **kwargs
     )
