From 474ac906a24b27ac49f7088a8d12a852437c1067 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 23 Mar 2022 16:06:00 -0700 Subject: [PATCH] Add 'head norm first' convnext_tiny_hnf weights --- timm/models/convnext.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 0a2df3de..8b7b5c85 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -44,7 +44,9 @@ default_cfgs = dict( convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), convnext_nano_hnf=_cfg(url=''), - convnext_tiny_hnf=_cfg(url=''), + convnext_tiny_hnf=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', + crop_pct=0.95), convnext_base_in22ft1k=_cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'), @@ -322,6 +324,8 @@ def _init_weights(module, name=None, head_init_scale=1.0): def checkpoint_filter_fn(state_dict, model): """ Remap FB checkpoints -> timm """ + if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict: + return state_dict # non-FB checkpoint if 'model' in state_dict: state_dict = state_dict['model'] out_dict = {}