Swin-V2 test fixes, typo

pull/1259/head
Ross Wightman 3 years ago
parent 9a86b900fa
commit c0211b0bf7

@ -25,7 +25,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
'poolformer_*', 'volo_*', 'sequencer2d_*']
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*']
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures

@ -39,7 +39,7 @@ def _cfg(url='', **kwargs):
default_cfgs = {
'swinv2_tiny_window8_256.': _cfg(
'swinv2_tiny_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
input_size=(3, 256, 256)
),
@ -106,6 +106,7 @@ def window_partition(x, window_size):
return windows
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows, window_size, H, W):
"""
Args:
@ -190,9 +191,11 @@ class WindowAttention(nn.Module):
self.qkv = nn.Linear(dim, dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(dim))
self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
self.v_bias = nn.Parameter(torch.zeros(dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
@ -208,7 +211,7 @@ class WindowAttention(nn.Module):
B_, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)

@ -51,7 +51,7 @@ def _cfg(url='', **kwargs):
'url': url,
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': None,
'pool_size': (7, 7),
'crop_pct': 0.9,
'interpolation': 'bicubic',
'fixed_input_size': True,
@ -65,14 +65,14 @@ def _cfg(url='', **kwargs):
default_cfgs = {
'swinv2_cr_tiny_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_tiny_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_tiny_ns_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_small_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_small_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
input_size=(3, 224, 224), crop_pct=0.9),
@ -80,21 +80,21 @@ default_cfgs = {
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_base_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_base_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_base_ns_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_large_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_large_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_huge_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_huge_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_giant_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_giant_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
}

Loading…
Cancel
Save