More model and test fixes

pull/1415/head
Ross Wightman 2 years ago
parent ca52108c2b
commit 8c9696c9df

@ -27,7 +27,9 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*']
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'coatne?t_*', 'max?vit_*',
]
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures

@ -43,7 +43,7 @@ def _cfg(url='', **kwargs):
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
'first_conv': 'stem.conv1', 'classifier': 'head.fc',
'fixed_input_size': True,
**kwargs
}
@ -106,7 +106,7 @@ class Downsample2d(nn.Module):
dim_out=None,
reduction='conv',
act_layer=nn.GELU,
norm_layer=LayerNorm2d,
norm_layer=LayerNorm2d, # NOTE in NCHW
):
super().__init__()
dim_out = dim_out or dim
@ -163,12 +163,10 @@ class Stem(nn.Module):
self,
in_chs: int = 3,
out_chs: int = 96,
act_layer: str = 'gelu',
norm_layer: str = 'layernorm2d', # NOTE norm for NCHW
act_layer: Callable = nn.GELU,
norm_layer: Callable = LayerNorm2d, # NOTE stem in NCHW
):
super().__init__()
act_layer = get_act_layer(act_layer)
norm_layer = get_norm_layer(norm_layer)
self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)
self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer)
@ -333,15 +331,11 @@ class GlobalContextVitStage(nn.Module):
proj_drop: float = 0.,
attn_drop: float = 0.,
drop_path: Union[List[float], float] = 0.0,
act_layer: str = 'gelu',
norm_layer: str = 'layernorm2d',
norm_layer_cl: str = 'layernorm',
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
norm_layer_cl: Callable = LayerNorm2d,
):
super().__init__()
act_layer = get_act_layer(act_layer)
norm_layer = get_norm_layer(norm_layer)
norm_layer_cl = get_norm_layer(norm_layer_cl)
if downsample:
self.downsample = Downsample2d(
dim=dim,
@ -421,8 +415,13 @@ class GlobalContextVit(nn.Module):
act_layer: str = 'gelu',
norm_layer: str = 'layernorm2d',
norm_layer_cl: str = 'layernorm',
norm_eps: float = 1e-5,
):
super().__init__()
act_layer = get_act_layer(act_layer)
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
img_size = to_2tuple(img_size)
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
self.global_pool = global_pool
@ -432,7 +431,11 @@ class GlobalContextVit(nn.Module):
self.num_features = int(embed_dim * 2 ** (num_stages - 1))
self.stem = Stem(
in_chs=in_chans, out_chs=embed_dim, act_layer=act_layer, norm_layer=norm_layer)
in_chs=in_chans,
out_chs=embed_dim,
act_layer=act_layer,
norm_layer=norm_layer
)
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
stages = []

@ -18,6 +18,7 @@ _NORM_ACT_MAP = dict(
batchnorm=BatchNormAct2d,
batchnorm2d=BatchNormAct2d,
groupnorm=GroupNormAct,
groupnorm1=functools.partial(GroupNormAct, num_groups=1),
layernorm=LayerNormAct,
layernorm2d=LayerNormAct2d,
evonormb0=EvoNorm2dB0,
@ -72,6 +73,8 @@ def get_norm_act_layer(norm_layer, act_layer=None):
norm_act_layer = BatchNormAct2d
elif type_name.startswith('groupnorm'):
norm_act_layer = GroupNormAct
elif type_name.startswith('groupnorm1'):
norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
elif type_name.startswith('layernorm2d'):
norm_act_layer = LayerNormAct2d
elif type_name.startswith('layernorm'):

@ -226,6 +226,7 @@ class LayerNormAct2d(nn.LayerNorm):
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self._fast_norm = is_fast_norm()
def forward(self, x):
x = x.permute(0, 2, 3, 1)

@ -24,6 +24,7 @@ import torch.utils.checkpoint as checkpoint
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg
from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
from .registry import register_model
@ -35,7 +36,8 @@ def _cfg(url='', **kwargs):
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': True,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
'fixed_input_size': True,
**kwargs
}
@ -169,6 +171,7 @@ class PatchEmbed(nn.Module):
return x.flatten(2).transpose(1, 2), x.shape[-2:]
@register_notrace_function
def reshape_pre_pool(
x,
feat_size: List[int],
@ -183,6 +186,7 @@ def reshape_pre_pool(
return x, cls_tok
@register_notrace_function
def reshape_post_pool(
x,
num_heads: int,
@ -196,6 +200,7 @@ def reshape_post_pool(
return x, feat_size
@register_notrace_function
def cal_rel_pos_type(
attn: torch.Tensor,
q: torch.Tensor,

@ -36,7 +36,7 @@ def _cfg(url='', **kwargs):
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.conv', 'classifier': 'head', 'fixed_input_size': False,
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False,
**kwargs
}

Loading…
Cancel
Save