From c0211b0bf79ee7e1009d04f11d27a061caa670b6 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 12 May 2022 22:31:55 -0700
Subject: [PATCH] Swin-V2 test fixes, typo

---
 tests/test_models.py                  |  2 +-
 timm/models/swin_transformer_v2.py    |  7 +++++--
 timm/models/swin_transformer_v2_cr.py | 14 +++++++-------
 3 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 6489892c..7ea9af6e 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -25,7 +25,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
-    'poolformer_*', 'volo_*', 'sequencer2d_*']
+    'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*']
 NUM_NON_STD = len(NON_STD_FILTERS)
 
 # exclude models that cause specific test failures
diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py
index 29c0be9e..700012fe 100644
--- a/timm/models/swin_transformer_v2.py
+++ b/timm/models/swin_transformer_v2.py
@@ -39,7 +39,7 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = {
-    'swinv2_tiny_window8_256.': _cfg(
+    'swinv2_tiny_window8_256': _cfg(
         url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
         input_size=(3, 256, 256)
     ),
@@ -106,6 +106,7 @@ def window_partition(x, window_size):
     return windows
 
 
+@register_notrace_function  # reason: int argument is a Proxy
 def window_reverse(windows, window_size, H, W):
     """
     Args:
@@ -190,9 +191,11 @@ class WindowAttention(nn.Module):
         self.qkv = nn.Linear(dim, dim * 3, bias=False)
         if qkv_bias:
             self.q_bias = nn.Parameter(torch.zeros(dim))
+            self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
             self.v_bias = nn.Parameter(torch.zeros(dim))
         else:
             self.q_bias = None
+            self.k_bias = None
             self.v_bias = None
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim)
@@ -208,7 +211,7 @@ class WindowAttention(nn.Module):
         B_, N, C = x.shape
         qkv_bias = None
         if self.q_bias is not None:
-            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
+            qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py
index 596ee204..fcfa217e 100644
--- a/timm/models/swin_transformer_v2_cr.py
+++ b/timm/models/swin_transformer_v2_cr.py
@@ -51,7 +51,7 @@ def _cfg(url='', **kwargs):
         'url': url,
         'num_classes': 1000,
         'input_size': (3, 224, 224),
-        'pool_size': None,
+        'pool_size': (7, 7),
         'crop_pct': 0.9,
         'interpolation': 'bicubic',
         'fixed_input_size': True,
@@ -65,14 +65,14 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = {
     'swinv2_cr_tiny_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0),
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
     'swinv2_cr_tiny_224': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_tiny_ns_224': _cfg(
         url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
         input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_small_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0),
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
     'swinv2_cr_small_224': _cfg(
         url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
         input_size=(3, 224, 224), crop_pct=0.9),
@@ -80,21 +80,21 @@ default_cfgs = {
         url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
         input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_base_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0),
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
     'swinv2_cr_base_224': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_base_ns_224': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_large_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0),
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
     'swinv2_cr_large_224': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_huge_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0),
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
     'swinv2_cr_huge_224': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
     'swinv2_cr_giant_384': _cfg(
-        url="", input_size=(3, 384, 384), crop_pct=1.0),
+        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
     'swinv2_cr_giant_224': _cfg(
         url="", input_size=(3, 224, 224), crop_pct=0.9),
 }
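
Note: the qkv bias change in swin_transformer_v2.py is the main behavioral edit here. The following is a minimal standalone sketch of that pattern, not the full timm WindowAttention; the module name QkvWithSplitBias is hypothetical and used only for illustration. The idea, as in the patch, is that q and v keep learnable bias parameters while k uses a fixed all-zero bias registered as a non-persistent buffer, so the zeros are not rebuilt with torch.zeros_like() on every forward, do not appear in the state_dict, and the concatenated bias passes cleanly through tracing.

# Minimal sketch (assumption: hypothetical module name, simplified forward)
import torch
import torch.nn as nn
import torch.nn.functional as F


class QkvWithSplitBias(nn.Module):
    def __init__(self, dim, num_heads, qkv_bias=True):
        super().__init__()
        self.num_heads = num_heads
        # single projection for q, k, v with no built-in bias
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            # fixed zero k bias kept as a non-persistent buffer (not in state_dict)
            self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

    def forward(self, x):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # concatenate per-projection biases into one (3 * dim,) vector
            qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
        qkv = F.linear(x, self.qkv.weight, qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        return q, k, v


# usage example (hypothetical sizes: dim=96, 3 heads, one 8x8 window of tokens):
# m = QkvWithSplitBias(dim=96, num_heads=3)
# q, k, v = m(torch.randn(2, 64, 96))

On the config side, the added pool_size values appear consistent with the overall 1/32 feature reduction implied by the existing defaults: 224 / 32 = 7 matches the new default pool_size=(7, 7), and 384 / 32 = 12 matches pool_size=(12, 12) for the 384-pixel models.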