|
|
|
@ -44,7 +44,9 @@ default_cfgs = dict(
|
|
|
|
|
convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
|
|
|
|
|
|
|
|
|
|
convnext_nano_hnf=_cfg(url=''),
|
|
|
|
|
convnext_tiny_hnf=_cfg(url=''),
|
|
|
|
|
convnext_tiny_hnf=_cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
|
|
|
|
|
crop_pct=0.95),
|
|
|
|
|
|
|
|
|
|
convnext_base_in22ft1k=_cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),
|
|
|
|
@ -322,6 +324,8 @@ def _init_weights(module, name=None, head_init_scale=1.0):
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
""" Remap FB checkpoints -> timm """
|
|
|
|
|
if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
|
|
|
|
|
return state_dict # non-FB checkpoint
|
|
|
|
|
if 'model' in state_dict:
|
|
|
|
|
state_dict = state_dict['model']
|
|
|
|
|
out_dict = {}
|
|
|
|
|