Fix some test failures, torchscript issues

pull/1415/head
Ross Wightman 2 years ago
parent 6e559e9b5f
commit f332fc2db7

@ -27,7 +27,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*']
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*']
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures

@ -26,7 +26,7 @@ from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
'crop_pct': .95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv1', 'classifier': 'head',

@ -209,8 +209,7 @@ class WindowAttentionGlobal(nn.Module):
def forward(self, x, q_global: Optional[torch.Tensor] = None):
B, N, C = x.shape
if self.use_global:
_assert(q_global is not None, 'q_global must be passed in global mode')
if self.use_global and q_global is not None:
_assert(x.shape[-1] == q_global.shape[-1], 'x and q_global seq lengths should be equal')
kv = self.qkv(x)

@ -286,7 +286,6 @@ class PyramidVisionTransformerV2(nn.Module):
self.num_classes = num_classes
assert global_pool in ('avg', '')
self.global_pool = global_pool
self.img_size = to_2tuple(img_size) if img_size is not None else None
self.depths = depths
num_stages = len(depths)
mlp_ratios = to_ntuple(num_stages)(mlp_ratios)
@ -324,7 +323,8 @@ class PyramidVisionTransformerV2(nn.Module):
cur += depths[i]
# classification head
self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
self.num_features = embed_dims[-1]
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)

Loading…
Cancel
Save