From d4c0588012a9b5d9fddd13035a9682acd9db0ad7 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 13 May 2022 10:49:11 -0700
Subject: [PATCH] Remove persistent buffers from Swin-V2. Change SwinV2Cr cos
 attn + tau/logit_scale to match official, add ckpt convert, init_value zeros
 resid LN weight by default

---
 timm/models/swin_transformer_v2.py    | 26 +++++++++++------
 timm/models/swin_transformer_v2_cr.py | 42 ++++++++++++++++++---------
 2 files changed, 46 insertions(+), 22 deletions(-)

diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py
index 8b4eff64..fe90144c 100644
--- a/timm/models/swin_transformer_v2.py
+++ b/timm/models/swin_transformer_v2.py
@@ -25,7 +25,6 @@ from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
 from .registry import register_model
-from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit
 
 
 def _cfg(url='', **kwargs):
@@ -75,7 +74,7 @@ default_cfgs = {
     ),
     'swinv2_base_window12to24_192to384_22kft1k': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
-        input_size=(3, 384, 384)
+        input_size=(3, 384, 384), crop_pct=1.0,
     ),
     'swinv2_large_window12_192_22k': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
@@ -87,7 +86,7 @@ default_cfgs = {
     ),
     'swinv2_large_window12to24_192to384_22kft1k': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
-        input_size=(3, 384, 384)
+        input_size=(3, 384, 384), crop_pct=1.0,
     ),
 }
 
@@ -174,7 +173,7 @@ class WindowAttention(nn.Module):
         relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
             torch.abs(relative_coords_table) + 1.0) / math.log2(8)
 
-        self.register_buffer("relative_coords_table", relative_coords_table)
+        self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
 
         # get pair-wise relative position index for each token inside the window
         coords_h = torch.arange(self.window_size[0])
@@ -187,7 +186,7 @@ class WindowAttention(nn.Module):
         relative_coords[:, :, 1] += self.window_size[1] - 1
         relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
         relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
-        self.register_buffer("relative_position_index", relative_position_index)
+        self.register_buffer("relative_position_index", relative_position_index, persistent=False)
 
         self.qkv = nn.Linear(dim, dim * 3, bias=False)
         if qkv_bias:
@@ -215,7 +214,7 @@ class WindowAttention(nn.Module):
             qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)
 
         # cosine attention
         attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
@@ -559,9 +558,6 @@ class SwinTransformerV2(nn.Module):
             trunc_normal_(m.weight, std=.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
 
     @torch.jit.ignore
     def no_weight_decay(self):
@@ -621,6 +617,18 @@ class SwinTransformerV2(nn.Module):
         return x
 
 
+def checkpoint_filter_fn(state_dict, model):
+    out_dict = {}
+    if 'model' in state_dict:
+        # For deit models
+        state_dict = state_dict['model']
+    for k, v in state_dict.items():
+        if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
+            continue  # skip buffers that should not be persistent
+        out_dict[k] = v
+    return out_dict
+
+
 def _create_swin_transformer_v2(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         SwinTransformerV2, variant, pretrained,
diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py
index fcfa217e..d143c14c 100644
--- a/timm/models/swin_transformer_v2_cr.py
+++ b/timm/models/swin_transformer_v2_cr.py
@@ -34,6 +34,7 @@ from typing import Tuple, Optional, List, Union, Any, Type
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -41,7 +42,7 @@ from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
 from .layers import DropPath, Mlp, to_2tuple, _assert
 from .registry import register_model
-from .vision_transformer import checkpoint_filter_fn
+
 
 _logger = logging.getLogger(__name__)
 
@@ -186,12 +187,13 @@ class WindowMultiHeadAttention(nn.Module):
             act_layer=nn.ReLU,
             drop=(0.125, 0.)  # FIXME should there be stochasticity, appears to 'overfit' without?
         )
-        self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads)))
+        # NOTE old checkpoints used inverse of logit_scale ('tau') following the paper, see conversion fn
+        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads)))
         self._make_pair_wise_relative_positions()
 
     def _make_pair_wise_relative_positions(self) -> None:
         """Method initializes the pair-wise relative positions to compute the positional biases."""
-        device = self.tau.device
+        device = self.logit_scale.device
         coordinates = torch.stack(torch.meshgrid([
             torch.arange(self.window_size[0], device=device),
             torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1)
@@ -250,10 +252,11 @@ class WindowMultiHeadAttention(nn.Module):
         query, key, value = qkv.unbind(0)
 
         # compute attention map with scaled cosine attention
-        denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1)
-        attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6)
-        attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1)
+        attn = (F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1))
+        logit_scale = torch.clamp(self.logit_scale.reshape(1, self.num_heads, 1, 1), max=math.log(1. / 0.01)).exp()
+        attn = attn * logit_scale
         attn = attn + self._relative_positional_encodings()
+
         if mask is not None:
             # Apply mask if utilized
             num_win: int = mask.shape[0]
@@ -309,7 +312,7 @@ class SwinTransformerBlock(nn.Module):
             window_size: Tuple[int, int],
             shift_size: Tuple[int, int] = (0, 0),
             mlp_ratio: float = 4.0,
-            init_values: float = 0,
+            init_values: Optional[float] = 0,
             drop: float = 0.0,
             drop_attn: float = 0.0,
             drop_path: float = 0.0,
@@ -323,7 +326,7 @@ class SwinTransformerBlock(nn.Module):
         self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
         self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
         self.window_area = self.window_size[0] * self.window_size[1]
-        self.init_values: float = init_values
+        self.init_values: Optional[float] = init_values
 
         # attn branch
         self.attn = WindowMultiHeadAttention(
@@ -387,7 +390,7 @@ class SwinTransformerBlock(nn.Module):
 
     def init_weights(self):
         # extra, module specific weight init
-        if self.init_values:
+        if self.init_values is not None:
             nn.init.constant_(self.norm1.weight, self.init_values)
             nn.init.constant_(self.norm2.weight, self.init_values)
 
@@ -536,7 +539,7 @@ class SwinTransformerStage(nn.Module):
             feat_size: Tuple[int, int],
             window_size: Tuple[int, int],
             mlp_ratio: float = 4.0,
-            init_values: float = 0.0,
+            init_values: Optional[float] = 0.0,
             drop: float = 0.0,
             drop_attn: float = 0.0,
             drop_path: Union[List[float], float] = 0.0,
@@ -650,7 +653,7 @@ class SwinTransformerV2Cr(nn.Module):
             depths: Tuple[int, ...] = (2, 2, 6, 2),
             num_heads: Tuple[int, ...] = (3, 6, 12, 24),
             mlp_ratio: float = 4.0,
-            init_values: float = 0.0,
+            init_values: Optional[float] = 0.,
             drop_rate: float = 0.0,
             attn_drop_rate: float = 0.0,
             drop_path_rate: float = 0.0,
@@ -808,6 +811,21 @@ def init_weights(module: nn.Module, name: str = ''):
         module.init_weights()
 
 
+def checkpoint_filter_fn(state_dict, model):
+    """ convert patch embedding weight from manual patchify + linear proj to conv"""
+    out_dict = {}
+    if 'model' in state_dict:
+        # For deit models
+        state_dict = state_dict['model']
+    for k, v in state_dict.items():
+        if 'tau' in k:
+            # convert old tau based checkpoints -> logit_scale (inverse)
+            v = torch.log(1 / v)
+            k = k.replace('tau', 'logit_scale')
+        out_dict[k] = v
+    return out_dict
+
+
 def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
@@ -890,7 +908,6 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs):
         embed_dim=96,
         depths=(2, 2, 18, 2),
         num_heads=(3, 6, 12, 24),
-        init_values=1e-5,
         extra_norm_stage=True,
         **kwargs
     )
@@ -928,7 +945,6 @@ def swinv2_cr_base_ns_224(pretrained=False, **kwargs):
         embed_dim=128,
         depths=(2, 2, 18, 2),
         num_heads=(4, 8, 16, 32),
-        init_values=1e-6,
         extra_norm_stage=True,
         **kwargs
     )
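
Note on the tau -> logit_scale change in swin_transformer_v2_cr.py: the old scaled cosine attention divided the similarity logits by a per-head tau, while the new code multiplies by exp(logit_scale), so checkpoint_filter_fn remaps old checkpoints with logit_scale = log(1 / tau). A minimal standalone sketch of that equivalence (not part of the patch; shapes and values are made up for illustration):

import math
import torch

num_heads = 4
tau = torch.rand(num_heads) + 0.5            # old per-head temperature, kept well above 0.01
logit_scale = torch.log(1 / tau)             # conversion applied by checkpoint_filter_fn
attn = torch.randn(2, num_heads, 49, 49)     # dummy cosine-similarity logits

# old scaling: divide by clamped tau
old = attn / tau.clamp(min=0.01).reshape(1, num_heads, 1, 1)
# new scaling: multiply by exp of logit_scale clamped at log(1 / 0.01), as in the patched forward()
new = attn * torch.clamp(logit_scale.reshape(1, num_heads, 1, 1), max=math.log(1. / 0.01)).exp()

print(torch.allclose(old, new, rtol=1e-5))   # True whenever tau >= 0.01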
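
The checkpoint_filter_fn added to swin_transformer_v2.py drops relative_position_index / relative_coords_table entries because those buffers are now registered with persistent=False and no longer appear in state_dict(); old checkpoints that still carry them would otherwise fail a strict load_state_dict(). A small standalone illustration (the Toy module is hypothetical, not timm code):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, persistent: bool):
        super().__init__()
        # same pattern as the window-attention index buffer, shrunk to a toy size
        self.register_buffer("relative_position_index", torch.arange(4), persistent=persistent)

old, new = Toy(persistent=True), Toy(persistent=False)
print(list(old.state_dict()))  # ['relative_position_index']
print(list(new.state_dict()))  # [], non-persistent buffers are not serialized

# dropping the stale key, as the filter does, keeps strict loading happy
filtered = {k: v for k, v in old.state_dict().items() if 'relative_position_index' not in k}
new.load_state_dict(filtered)  # ok; loading old.state_dict() directly would report an unexpected key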