Add 'head norm first' convnext_tiny_hnf weights

pull/1190/head
Ross Wightman 3 years ago
parent dc51334cdc
commit 474ac906a2

@ -44,7 +44,9 @@ default_cfgs = dict(
convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
convnext_nano_hnf=_cfg(url=''), convnext_nano_hnf=_cfg(url=''),
convnext_tiny_hnf=_cfg(url=''), convnext_tiny_hnf=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
crop_pct=0.95),
convnext_base_in22ft1k=_cfg( convnext_base_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'), url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),
@ -322,6 +324,8 @@ def _init_weights(module, name=None, head_init_scale=1.0):
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model):
""" Remap FB checkpoints -> timm """ """ Remap FB checkpoints -> timm """
if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
return state_dict # non-FB checkpoint
if 'model' in state_dict: if 'model' in state_dict:
state_dict = state_dict['model'] state_dict = state_dict['model']
out_dict = {} out_dict = {}

Loading…
Cancel
Save